khalidsaifullaah commited on
Commit
3df3a47
1 Parent(s): a8e4fc0

CC12M downloader script added

Browse files
Files changed (1) hide show
  1. data/CC12M_downloader.py +91 -0
data/CC12M_downloader.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)
2
+
3
+ #%%
4
+ import sys
5
+ import os
6
+ from datetime import datetime
7
+ import pandas as pd
8
+ import contexttimer
9
+ from urllib.request import urlopen
10
+ import requests
11
+ from PIL import Image
12
+ import torch
13
+ from torchvision.transforms import functional as TF
14
+ from multiprocessing import Pool
15
+ from tqdm import tqdm
16
+ import logging
17
+
18
+ # Setup
19
+ logging.basicConfig(filename='download.log', filemode='w', level=logging.INFO)
20
+ requests.packages.urllib3.disable_warnings(requests.packages.urllib3.exceptions.InsecureRequestWarning)
21
+
22
+
23
+ # # For downloading SVG images (I can't get this to work)
24
+ # from io import BytesIO
25
+ # import cairosvg
26
+
27
+ #%%
28
+ # Load data
29
+ print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
30
+ with contexttimer.Timer(prefix="Loading from tsv"):
31
+ df = pd.read_csv('./cc12m.tsv', delimiter='\t', header=None)
32
+
33
+ url_to_idx_map = {url: index for index, url, caption in df.itertuples()}
34
+ print(f'Loaded {len(url_to_idx_map)} urls')
35
+
36
+ #%%
37
+ df.head()
38
+
39
+ #%%
40
+
41
+ # Note: it seems that there are no SVG images
42
+ df.sample(10000)[1].str.contains('.svg').sum()
43
+
44
+ #%%
45
+ # Resize function
46
+ def resize(img):
47
+ max_size_of_short_side = 512
48
+ if min(img.size) > max_size_of_short_side:
49
+ img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
50
+ return img
51
+
52
+ base_dir = os.path.join(os.getcwd(), 'images')
53
+
54
+ def process(item):
55
+ url, image_id = item
56
+ try:
57
+ base_url = os.path.basename(url) # extract base url
58
+ stem, ext = os.path.splitext(base_url) # split into stem and extension
59
+ filename = f'{image_id:08d}---{stem}.jpg' # create filename
60
+ filepath = os.path.join(base_dir, filename) # concat to get filepath
61
+ if not os.path.isfile(filepath):
62
+ # if filepath.endswith('.svg'):
63
+ # raise NotImplementedError()
64
+ # image_bytes = BytesIO() # create a bytestream
65
+ # cairosvg.svg2png(url=url, write_to=image_bytes) # convert svg into image
66
+ # else:
67
+ req = requests.get(url, stream=True, timeout=1, verify=False).raw
68
+ image = Image.open(req).convert('RGB')
69
+ if min(image.size) > 512:
70
+ image = TF.resize(image, size=512, interpolation=Image.LANCZOS)
71
+ # image = resize(image) # resize PIL image
72
+ image.save(filepath) # save PIL image
73
+ except Exception as e:
74
+ logging.info(" ".join(repr(e).splitlines()))
75
+ logging.error(url)
76
+
77
+ #%%
78
+ #for i, item in enumerate(tqdm(url_to_idx_map.items(), total=len(url_to_idx_map))):
79
+ # process(item)
80
+ # if i > 100:
81
+ # break
82
+
83
+ # Use multiprocessing for speed
84
+ list_of_items = list(url_to_idx_map.items())
85
+ print(len(list_of_items))
86
+ list_of_items = list_of_items[10_000_000:]
87
+ print(len(list_of_items))
88
+ with Pool(128) as p:
89
+ r = list(tqdm(p.imap(process, list_of_items), total=len(list_of_items)))
90
+ print('DONE')
91
+