rsanjaykamath committed on
Commit
7fc7f3d
1 Parent(s): eb43f71
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .DS_Store +0 -0
  2. .idea/.gitignore +3 -0
  3. LICENSE.txt +12 -0
  4. README 2.md +46 -0
  5. README.md +40 -6
  6. __pycache__/run_code.cpython-38.pyc +0 -0
  7. app.py +232 -0
  8. app_run.ipynb +400 -0
  9. configs/caption_coco.yaml +33 -0
  10. configs/med_config.json +21 -0
  11. configs/nlvr.yaml +21 -0
  12. configs/nocaps.yaml +15 -0
  13. configs/pretrain.yaml +27 -0
  14. configs/retrieval_coco.yaml +34 -0
  15. configs/retrieval_flickr.yaml +34 -0
  16. configs/vqa.yaml +25 -0
  17. data/__init__.py +101 -0
  18. data/coco_karpathy_dataset.py +126 -0
  19. data/flickr30k_dataset.py +93 -0
  20. data/nlvr_dataset.py +78 -0
  21. data/nocaps_dataset.py +32 -0
  22. data/pretrain_dataset.py +59 -0
  23. data/utils.py +112 -0
  24. data/vqa_dataset.py +88 -0
  25. elephant.jpg +0 -0
  26. eval_nocaps.py +118 -0
  27. examples/ex1.jpg +0 -0
  28. examples/ex2.jpg +0 -0
  29. examples/ex3.jpg +0 -0
  30. extras/.DS_Store +0 -0
  31. extras/sample-images/0.JPG +0 -0
  32. extras/sample-images/1.JPG +0 -0
  33. extras/sample-images/10.jpg +0 -0
  34. extras/sample-images/2.jpg +0 -0
  35. extras/sample-images/3.jpg +0 -0
  36. extras/sample-images/4.jpg +0 -0
  37. extras/sample-images/5.jpg +0 -0
  38. extras/sample-images/6.JPG +0 -0
  39. extras/sample-images/7.JPG +0 -0
  40. extras/sample-images/8.jpg +0 -0
  41. extras/sample-images/9.jpg +0 -0
  42. foo.png +0 -0
  43. gradio_cached_examples/log.csv +2 -0
  44. local_run.ipynb +347 -0
  45. model-data/.DS_Store +0 -0
  46. model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 +3 -0
  47. model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 +3 -0
  48. model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 +3 -0
  49. model-data/weights/readme.md +1 -0
  50. modelsn/__init__.py +0 -0
.DS_Store ADDED
Binary file (10.2 kB).
 
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
+
+ # Default ignored files
+ /workspace.xml
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+ Copyright (c) 2022, Salesforce.com, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README 2.md ADDED
@@ -0,0 +1,46 @@
+ ---
+ title: PPE_Detection
+ emoji: 💩
+ colorFrom: pink
+ colorTo: indigo
+ sdk: gradio
+ app_file: app.py
+ pinned: false
+ license: other
+ ---
+
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio`, `streamlit`, or `static`
+
+ `sdk_version` : _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+ Path is relative to the root of the repository.
+
+ `models`: _List[string]_
+ HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
+ Will be parsed automatically from your code if not specified here.
+
+ `datasets`: _List[string]_
+ HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
+ Will be parsed automatically from your code if not specified here.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
README.md CHANGED
@@ -1,12 +1,46 @@
  ---
- title: Safeworld_Captioning_Spaces
- emoji: 📚
- colorFrom: indigo
- colorTo: purple
+ title: BLIP
+ emoji: 🦀
+ colorFrom: red
+ colorTo: blue
  sdk: gradio
  app_file: app.py
  pinned: false
- license: other
+ license: bsd-3-clause
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio`, `streamlit`, or `static`
+
+ `sdk_version` : _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+ Path is relative to the root of the repository.
+
+ `models`: _List[string]_
+ HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
+ Will be parsed automatically from your code if not specified here.
+
+ `datasets`: _List[string]_
+ HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
+ Will be parsed automatically from your code if not specified here.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
__pycache__/run_code.cpython-38.pyc ADDED
Binary file (3.18 kB).
 
app.py ADDED
@@ -0,0 +1,232 @@
+ import os
+
+ os.system(
+     "wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg")
+
+ from PIL import Image
+ import requests
+ import torch
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ # MDETR Code
+ import torchvision.transforms as T
+ import matplotlib.pyplot as plt
+ from collections import defaultdict
+ import torch.nn.functional as F
+ import numpy as np
+ from skimage.measure import find_contours
+
+ from matplotlib import patches, lines
+ from matplotlib.patches import Polygon
+ import gradio as gr
+
+ torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg',
+                                'elephant.jpg')
+
+ model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True,
+                                        return_postprocessor=True)
+ model2 = model2.cpu()
+ model2.eval()
+
+ torch.set_grad_enabled(False);
+ # standard PyTorch mean-std input image normalization
+ transform = T.Compose([
+     T.Resize(800),
+     T.ToTensor(),
+     T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+ ])
+
+
+ # for output bounding box post-processing
+ def box_cxcywh_to_xyxy(x):
+     x_c, y_c, w, h = x.unbind(1)
+     b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+          (x_c + 0.5 * w), (y_c + 0.5 * h)]
+     return torch.stack(b, dim=1)
+
+
+ def rescale_bboxes(out_bbox, size):
+     img_w, img_h = size
+     b = box_cxcywh_to_xyxy(out_bbox)
+     b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+     return b
+
+
+ # colors for visualization
+ COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+           [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+
+ def apply_mask(image, mask, color, alpha=0.5):
+     """Apply the given mask to the image.
+     """
+     for c in range(3):
+         image[:, :, c] = np.where(mask == 1,
+                                   image[:, :, c] *
+                                   (1 - alpha) + alpha * color[c] * 255,
+                                   image[:, :, c])
+     return image
+
+
+ def plot_results(pil_img, scores, boxes, labels, masks=None):
+     plt.figure(figsize=(16, 10))
+     np_image = np.array(pil_img)
+     ax = plt.gca()
+     colors = COLORS * 100
+     if masks is None:
+         masks = [None for _ in range(len(scores))]
+     assert len(scores) == len(boxes) == len(labels) == len(masks)
+     for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):
+         ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                    fill=False, color=c, linewidth=3))
+         text = f'{l}: {s:0.2f}'
+         ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))
+
+         if mask is None:
+             continue
+         np_image = apply_mask(np_image, mask, c)
+
+         padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
+         padded_mask[1:-1, 1:-1] = mask
+         contours = find_contours(padded_mask, 0.5)
+         for verts in contours:
+             # Subtract the padding and flip (y, x) to (x, y)
+             verts = np.fliplr(verts) - 1
+             p = Polygon(verts, facecolor="none", edgecolor=c)
+             ax.add_patch(p)
+
+     plt.imshow(np_image)
+     plt.axis('off')
+     plt.savefig('foo.png', bbox_inches='tight')
+     return 'foo.png'
+
+
+ def add_res(results, ax, color='green'):
+     # for tt in results.values():
+     if True:
+         bboxes = results['boxes']
+         labels = results['labels']
+         scores = results['scores']
+         # keep = scores >= 0.0
+         # bboxes = bboxes[keep].tolist()
+         # labels = labels[keep].tolist()
+         # scores = scores[keep].tolist()
+         # print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]])))
+
+     colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']
+
+     for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):
+         ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))
+         cls_name = ll if isinstance(ll, str) else CLASSES[ll]
+         text = f'{cls_name}: {ss:.2f}'
+         print(text)
+         ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))
+
+
+ def plot_inference(im, caption, approaches):
+     choices = {"Worker Helmet Separately": 1, "Worker Helmet Vest": 2, "Workers only": 3}
+
+     # mean-std normalize the input image (batch-size: 1)
+     img = transform(im).unsqueeze(0).cpu()
+
+     # propagate through the model
+     memory_cache = model2(img, [caption], encode_and_save=True)
+     outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache)
+
+     # keep only predictions with 0.7+ confidence
+     probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()
+     keep = (probas > 0.7).cpu()
+
+     # convert boxes from [0; 1] to image scales
+     bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)
+
+     # Extract the text spans predicted by each box
+     positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()
+     predicted_spans = defaultdict(str)
+     for tok in positive_tokens:
+         item, pos = tok
+         if pos < 255:
+             span = memory_cache["tokenized"].token_to_chars(0, pos)
+             predicted_spans[item] += " " + caption[span.start:span.end]
+
+     labels = [predicted_spans[k] for k in sorted(list(predicted_spans.keys()))]
+     caption = 'Caption: ' + caption
+     return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches]))
+
+
+ # BLIP Code
+
+
+ from modelsn.blip import blip_decoder
+
+ image_size = 384
+ transform = transforms.Compose([
+     transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+ model = blip_decoder(pretrained=model_url, image_size=384, vit='base')
+ model.eval()
+ model = model.to(device)
+
+ from modelsn.blip_vqa import blip_vqa
+
+ image_size_vq = 480
+ transform_vq = transforms.Compose([
+     transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+ ])
+
+ model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
+
+ model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
+ model_vq.eval()
+ model_vq = model_vq.to(device)
+
+
+ def inference(raw_image, approaches, question):
+     image = transform(raw_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+
+     return (plot_inference(raw_image, caption[0], approaches))
+     # return 'caption: '+caption[0]
+
+
+ # PPE Detection code
+ import numpy as np
+ import run_code
+ import gradio as gr
+
+
+ def sepia_call(caption, Input_Image, MDETR_im, Approach):
+     pil_image = Input_Image
+     open_cv_image = np.asarray(pil_image)
+     sepia_img = run_code.run(open_cv_image, Approach)
+     images = sepia_img['img']
+     texts = sepia_img['text']
+
+     return (caption, MDETR_im, images, texts)
+
+
+ inputs = [gr.inputs.Image(type='pil'),
+           gr.inputs.Radio(choices=["Worker Helmet Separately", "Worker Helmet Vest", "Workers only"], type="value",
+                           default="Worker Helmet Vest", label="Model"), "textbox"]
+ outputs = [gr.outputs.Textbox(label="Output"), "image", "image", gr.outputs.Textbox(label="Output")]
+
+ title = "BLIP + MDETR + PPE Detection"
+
+ description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+
+ article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"
+
+ gr.Interface(inference, inputs, outputs, title=title, description=description, article=article,
+              examples=[['starry.jpg', "Image Captioning", "None"]]).launch(share=True, enable_queue=True,
+                                                                            cache_examples=False)
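
The box post-processing in `app.py` (`box_cxcywh_to_xyxy` / `rescale_bboxes`) is self-contained, so it can be sanity-checked without loading MDETR. A minimal sketch, assuming only PyTorch is installed; the two helpers are re-implemented here rather than imported from `app.py`, since importing the module would trigger the model and image downloads:

```python
import torch

# Mirror of the two helpers in app.py: MDETR emits (cx, cy, w, h) boxes
# normalized to [0, 1]; convert them to (x0, y0, x1, y1) and scale to pixels.
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    return torch.stack([x_c - 0.5 * w, y_c - 0.5 * h,
                        x_c + 0.5 * w, y_c + 0.5 * h], dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    return box_cxcywh_to_xyxy(out_bbox) * torch.tensor([img_w, img_h, img_w, img_h],
                                                       dtype=torch.float32)

# A centered box covering half of a 640x480 image (illustrative numbers only):
print(rescale_bboxes(torch.tensor([[0.5, 0.5, 0.5, 0.5]]), (640, 480)))
# tensor([[160., 120., 480., 360.]])
```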
app_run.ipynb ADDED
@@ -0,0 +1,400 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 3,
6
+ "id": "15468c81",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "--2022-02-15 18:26:17-- https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\n",
14
+ "Resolving upload.wikimedia.org (upload.wikimedia.org)... 91.198.174.208\n",
15
+ "Connecting to upload.wikimedia.org (upload.wikimedia.org)|91.198.174.208|:443... connected.\n",
16
+ "HTTP request sent, awaiting response... 200 OK\n",
17
+ "Length: 1388211 (1.3M) [image/jpeg]\n",
18
+ "Saving to: ‘starry.jpg’\n",
19
+ "\n",
20
+ " 0K .......... .......... .......... .......... .......... 3% 776K 2s\n",
21
+ " 50K .......... .......... .......... .......... .......... 7% 877K 2s\n",
22
+ " 100K .......... .......... .......... .......... .......... 11% 2.93M 1s\n",
23
+ " 150K .......... .......... .......... .......... .......... 14% 2.28M 1s\n",
24
+ " 200K .......... .......... .......... .......... .......... 18% 4.04M 1s\n",
25
+ " 250K .......... .......... .......... .......... .......... 22% 5.46M 1s\n",
26
+ " 300K .......... .......... .......... .......... .......... 25% 6.40M 1s\n",
27
+ " 350K .......... .......... .......... .......... .......... 29% 2.41M 0s\n",
28
+ " 400K .......... .......... .......... .......... .......... 33% 3.18M 0s\n",
29
+ " 450K .......... .......... .......... .......... .......... 36% 3.03M 0s\n",
30
+ " 500K .......... .......... .......... .......... .......... 40% 8.30M 0s\n",
31
+ " 550K .......... .......... .......... .......... .......... 44% 3.31M 0s\n",
32
+ " 600K .......... .......... .......... .......... .......... 47% 3.10M 0s\n",
33
+ " 650K .......... .......... .......... .......... .......... 51% 12.3M 0s\n",
34
+ " 700K .......... .......... .......... .......... .......... 55% 4.20M 0s\n",
35
+ " 750K .......... .......... .......... .......... .......... 59% 1.93M 0s\n",
36
+ " 800K .......... .......... .......... .......... .......... 62% 6.28M 0s\n",
37
+ " 850K .......... .......... .......... .......... .......... 66% 3.09M 0s\n",
38
+ " 900K .......... .......... .......... .......... .......... 70% 22.7M 0s\n",
39
+ " 950K .......... .......... .......... .......... .......... 73% 4.43M 0s\n",
40
+ " 1000K .......... .......... .......... .......... .......... 77% 4.16M 0s\n",
41
+ " 1050K .......... .......... .......... .......... .......... 81% 2.29M 0s\n",
42
+ " 1100K .......... .......... .......... .......... .......... 84% 1.81M 0s\n",
43
+ " 1150K .......... .......... .......... .......... .......... 88% 6.20M 0s\n",
44
+ " 1200K .......... .......... .......... .......... .......... 92% 2.03M 0s\n",
45
+ " 1250K .......... .......... .......... .......... .......... 95% 23.5M 0s\n",
46
+ " 1300K .......... .......... .......... .......... .......... 99% 5.04M 0s\n",
47
+ " 1350K ..... 100% 9.95M=0.5s\n",
48
+ "\n",
49
+ "2022-02-15 18:26:17 (2.89 MB/s) - ‘starry.jpg’ saved [1388211/1388211]\n",
50
+ "\n"
51
+ ]
52
+ },
53
+ {
54
+ "data": {
55
+ "application/vnd.jupyter.widget-view+json": {
56
+ "model_id": "02b7655f0b2b404b952b7c152a3a1661",
57
+ "version_major": 2,
58
+ "version_minor": 0
59
+ },
60
+ "text/plain": [
61
+ " 0%| | 0.00/262k [00:00<?, ?B/s]"
62
+ ]
63
+ },
64
+ "metadata": {},
65
+ "output_type": "display_data"
66
+ },
67
+ {
68
+ "name": "stderr",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "Using cache found in /Users/sanjaykamath/.cache/torch/hub/ashkamath_mdetr_main\n",
72
+ "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight']\n",
73
+ "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
74
+ "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
75
+ ]
76
+ },
77
+ {
78
+ "name": "stdout",
79
+ "output_type": "stream",
80
+ "text": [
81
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
82
+ "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth\n",
83
+ "Running on local URL: http://127.0.0.1:7862/\n",
84
+ "Running on public URL: https://13389.gradio.app\n",
85
+ "\n",
86
+ "This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)\n"
87
+ ]
88
+ },
89
+ {
90
+ "data": {
91
+ "text/html": [
92
+ "\n",
93
+ " <iframe\n",
94
+ " width=\"900\"\n",
95
+ " height=\"500\"\n",
96
+ " src=\"https://13389.gradio.app\"\n",
97
+ " frameborder=\"0\"\n",
98
+ " allowfullscreen\n",
99
+ " \n",
100
+ " ></iframe>\n",
101
+ " "
102
+ ],
103
+ "text/plain": [
104
+ "<IPython.lib.display.IFrame at 0x7fce90855f40>"
105
+ ]
106
+ },
107
+ "metadata": {},
108
+ "output_type": "display_data"
109
+ },
110
+ {
111
+ "data": {
112
+ "text/plain": [
113
+ "(<fastapi.applications.FastAPI at 0x7fcfa3376fd0>,\n",
114
+ " 'http://127.0.0.1:7862/',\n",
115
+ " 'https://13389.gradio.app')"
116
+ ]
117
+ },
118
+ "execution_count": 3,
119
+ "metadata": {},
120
+ "output_type": "execute_result"
121
+ },
122
+ {
123
+ "name": "stderr",
124
+ "output_type": "stream",
125
+ "text": [
126
+ "2022-02-15 18:27:19.011924: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
127
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
128
+ ]
129
+ }
130
+ ],
131
+ "source": [
132
+ "import os\n",
133
+ "os.system(\"wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg\")\n",
134
+ "\n",
135
+ "from PIL import Image\n",
136
+ "import requests\n",
137
+ "import torch\n",
138
+ "from torchvision import transforms\n",
139
+ "from torchvision.transforms.functional import InterpolationMode\n",
140
+ "\n",
141
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
142
+ "\n",
143
+ "\n",
144
+ "\n",
145
+ " \n",
146
+ "#MDETR Code \n",
147
+ "import torchvision.transforms as T\n",
148
+ "import matplotlib.pyplot as plt\n",
149
+ "from collections import defaultdict\n",
150
+ "import torch.nn.functional as F\n",
151
+ "import numpy as np\n",
152
+ "from skimage.measure import find_contours\n",
153
+ "\n",
154
+ "from matplotlib import patches, lines\n",
155
+ "from matplotlib.patches import Polygon\n",
156
+ "import gradio as gr\n",
157
+ "\n",
158
+ "torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg', 'elephant.jpg')\n",
159
+ "\n",
160
+ "\n",
161
+ "model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True, return_postprocessor=True)\n",
162
+ "model2 = model2.cpu()\n",
163
+ "model2.eval()\n",
164
+ "\n",
165
+ "\n",
166
+ "\n",
167
+ "\n",
168
+ "torch.set_grad_enabled(False);\n",
169
+ "# standard PyTorch mean-std input image normalization\n",
170
+ "transform = T.Compose([\n",
171
+ " T.Resize(800),\n",
172
+ " T.ToTensor(),\n",
173
+ " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
174
+ "])\n",
175
+ "\n",
176
+ "# for output bounding box post-processing\n",
177
+ "def box_cxcywh_to_xyxy(x):\n",
178
+ " x_c, y_c, w, h = x.unbind(1)\n",
179
+ " b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n",
180
+ " (x_c + 0.5 * w), (y_c + 0.5 * h)]\n",
181
+ " return torch.stack(b, dim=1)\n",
182
+ "\n",
183
+ "def rescale_bboxes(out_bbox, size):\n",
184
+ " img_w, img_h = size\n",
185
+ " b = box_cxcywh_to_xyxy(out_bbox)\n",
186
+ " b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)\n",
187
+ " return b\n",
188
+ "# colors for visualization\n",
189
+ "COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],\n",
190
+ " [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]\n",
191
+ "\n",
192
+ "def apply_mask(image, mask, color, alpha=0.5):\n",
193
+ " \"\"\"Apply the given mask to the image.\n",
194
+ " \"\"\"\n",
195
+ " for c in range(3):\n",
196
+ " image[:, :, c] = np.where(mask == 1,\n",
197
+ " image[:, :, c] *\n",
198
+ " (1 - alpha) + alpha * color[c] * 255,\n",
199
+ " image[:, :, c])\n",
200
+ " return image\n",
201
+ "\n",
202
+ "def plot_results(pil_img, scores, boxes, labels, masks=None):\n",
203
+ " plt.figure(figsize=(16,10))\n",
204
+ " np_image = np.array(pil_img)\n",
205
+ " ax = plt.gca()\n",
206
+ " colors = COLORS * 100\n",
207
+ " if masks is None:\n",
208
+ " masks = [None for _ in range(len(scores))]\n",
209
+ " assert len(scores) == len(boxes) == len(labels) == len(masks)\n",
210
+ " for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):\n",
211
+ " ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,\n",
212
+ " fill=False, color=c, linewidth=3))\n",
213
+ " text = f'{l}: {s:0.2f}'\n",
214
+ " ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n",
215
+ "\n",
216
+ " if mask is None:\n",
217
+ " continue\n",
218
+ " np_image = apply_mask(np_image, mask, c)\n",
219
+ "\n",
220
+ " padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)\n",
221
+ " padded_mask[1:-1, 1:-1] = mask\n",
222
+ " contours = find_contours(padded_mask, 0.5)\n",
223
+ " for verts in contours:\n",
224
+ " # Subtract the padding and flip (y, x) to (x, y)\n",
225
+ " verts = np.fliplr(verts) - 1\n",
226
+ " p = Polygon(verts, facecolor=\"none\", edgecolor=c)\n",
227
+ " ax.add_patch(p)\n",
228
+ "\n",
229
+ "\n",
230
+ " plt.imshow(np_image)\n",
231
+ " plt.axis('off')\n",
232
+ " plt.savefig('foo.png',bbox_inches='tight')\n",
233
+ " return 'foo.png'\n",
234
+ "\n",
235
+ "\n",
236
+ "def add_res(results, ax, color='green'):\n",
237
+ " #for tt in results.values():\n",
238
+ " if True:\n",
239
+ " bboxes = results['boxes']\n",
240
+ " labels = results['labels']\n",
241
+ " scores = results['scores']\n",
242
+ " #keep = scores >= 0.0\n",
243
+ " #bboxes = bboxes[keep].tolist()\n",
244
+ " #labels = labels[keep].tolist()\n",
245
+ " #scores = scores[keep].tolist()\n",
246
+ " #print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]])))\n",
247
+ " \n",
248
+ " colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']\n",
249
+ " \n",
250
+ " for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):\n",
251
+ " ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))\n",
252
+ " cls_name = ll if isinstance(ll,str) else CLASSES[ll]\n",
253
+ " text = f'{cls_name}: {ss:.2f}'\n",
254
+ " print(text)\n",
255
+ " ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n",
256
+ "\n",
257
+ "\n",
258
+ "def plot_inference(im, caption, approaches):\n",
259
+ " \n",
260
+ " choices = {\"Worker Helmet Separately\" : 1,\"Worker Helmet Vest\":2, \"Workers only\":3}\n",
261
+ " \n",
262
+ " \n",
263
+ "# mean-std normalize the input image (batch-size: 1)\n",
264
+ " img = transform(im).unsqueeze(0).cpu()\n",
265
+ "\n",
266
+ " # propagate through the model\n",
267
+ " memory_cache = model2(img, [caption], encode_and_save=True)\n",
268
+ " outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache)\n",
269
+ "\n",
270
+ " # keep only predictions with 0.7+ confidence\n",
271
+ " probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()\n",
272
+ " keep = (probas > 0.7).cpu()\n",
273
+ "\n",
274
+ " # convert boxes from [0; 1] to image scales\n",
275
+ " bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)\n",
276
+ "\n",
277
+ " # Extract the text spans predicted by each box\n",
278
+ " positive_tokens = (outputs[\"pred_logits\"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()\n",
279
+ " predicted_spans = defaultdict(str)\n",
280
+ " for tok in positive_tokens:\n",
281
+ " item, pos = tok\n",
282
+ " if pos < 255:\n",
283
+ " span = memory_cache[\"tokenized\"].token_to_chars(0, pos)\n",
284
+ " predicted_spans [item] += \" \" + caption[span.start:span.end]\n",
285
+ "\n",
286
+ " labels = [predicted_spans [k] for k in sorted(list(predicted_spans .keys()))]\n",
287
+ " caption = 'Caption: '+ caption\n",
288
+ " return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches]))\n",
289
+ " \n",
290
+ "\n",
291
+ "\n",
292
+ " \n",
293
+ "#BLIP Code\n",
294
+ "\n",
295
+ "\n",
296
+ "from modelsn.blip import blip_decoder\n",
297
+ "\n",
298
+ "image_size = 384\n",
299
+ "transform = transforms.Compose([\n",
300
+ " transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
301
+ " transforms.ToTensor(),\n",
302
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
303
+ " ]) \n",
304
+ "\n",
305
+ "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
306
+ " \n",
307
+ "model = blip_decoder(pretrained=model_url, image_size=384, vit='base')\n",
308
+ "model.eval()\n",
309
+ "model = model.to(device)\n",
310
+ "\n",
311
+ "\n",
312
+ "from modelsn.blip_vqa import blip_vqa\n",
313
+ "\n",
314
+ "image_size_vq = 480\n",
315
+ "transform_vq = transforms.Compose([\n",
316
+ " transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),\n",
317
+ " transforms.ToTensor(),\n",
318
+ " transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
319
+ " ]) \n",
320
+ "\n",
321
+ "model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'\n",
322
+ " \n",
323
+ "model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')\n",
324
+ "model_vq.eval()\n",
325
+ "model_vq = model_vq.to(device)\n",
326
+ "\n",
327
+ "\n",
328
+ "\n",
329
+ "def inference(raw_image, approaches, question):\n",
330
+ " \n",
331
+ "\n",
332
+ " image = transform(raw_image).unsqueeze(0).to(device) \n",
333
+ " with torch.no_grad():\n",
334
+ " caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)\n",
335
+ "\n",
336
+ " return (plot_inference(raw_image, caption[0], approaches))\n",
337
+ " #return 'caption: '+caption[0]\n",
338
+ "\n",
339
+ " \n",
340
+ "\n",
341
+ " \n",
342
+ "#PPE Detection code\n",
343
+ "import numpy as np\n",
344
+ "import run_code\n",
345
+ "import gradio as gr\n",
346
+ " \n",
347
+ "\n",
348
+ "def sepia_call(caption, Input_Image, MDETR_im, Approach):\n",
349
+ " pil_image = Input_Image\n",
350
+ " open_cv_image = np.asarray(pil_image)\n",
351
+ " sepia_img = run_code.run(open_cv_image, Approach)\n",
352
+ " images = sepia_img['img']\n",
353
+ " texts= sepia_img['text']\n",
354
+ "\n",
355
+ " return (caption, MDETR_im, images, texts)\n",
356
+ "\n",
357
+ "\n",
358
+ "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=[\"Worker Helmet Separately\",\"Worker Helmet Vest\", \"Workers only\"], type=\"value\", default=\"Worker Helmet Vest\", label=\"Model\"),\"textbox\"]\n",
359
+ "outputs = [gr.outputs.Textbox(label=\"Output\"), \"image\", \"image\", gr.outputs.Textbox(label=\"Output\")]\n",
360
+ "\n",
361
+ "\n",
362
+ "title = \"BLIP + MDETR + PPE Detection\"\n",
363
+ "\n",
364
+ "description = \"Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.\"\n",
365
+ "\n",
366
+ "article = \"<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>\"\n",
367
+ "\n",
368
+ "\n",
369
+ "gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starry.jpg',\"Image Captioning\",\"None\"]]).launch(share=True,enable_queue=True,cache_examples=False)"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "raw",
374
+ "id": "b2729aa9",
375
+ "metadata": {},
376
+ "source": []
377
+ }
378
+ ],
379
+ "metadata": {
380
+ "kernelspec": {
381
+ "display_name": "Python 3 (ipykernel)",
382
+ "language": "python",
383
+ "name": "python3"
384
+ },
385
+ "language_info": {
386
+ "codemirror_mode": {
387
+ "name": "ipython",
388
+ "version": 3
389
+ },
390
+ "file_extension": ".py",
391
+ "mimetype": "text/x-python",
392
+ "name": "python",
393
+ "nbconvert_exporter": "python",
394
+ "pygments_lexer": "ipython3",
395
+ "version": "3.8.12"
396
+ }
397
+ },
398
+ "nbformat": 4,
399
+ "nbformat_minor": 5
400
+ }
configs/caption_coco.yaml ADDED
@@ -0,0 +1,33 @@
+ image_root: '/export/share/datasets/vision/coco/images/'
+ ann_root: 'annotation'
+ coco_gt_root: 'annotation/coco_gt'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ batch_size: 32
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 5
+ # batch_size: 16
+ # init_lr: 2e-6
+
+ image_size: 384
+
+ # generation configs
+ max_length: 20
+ min_length: 5
+ num_beams: 3
+ prompt: 'a picture of '
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 5
+
configs/med_config.json ADDED
@@ -0,0 +1,21 @@
+ {
+   "architectures": [
+     "BertModel"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "layer_norm_eps": 1e-12,
+   "max_position_embeddings": 512,
+   "model_type": "bert",
+   "num_attention_heads": 12,
+   "num_hidden_layers": 12,
+   "pad_token_id": 0,
+   "type_vocab_size": 2,
+   "vocab_size": 30524,
+   "encoder_width": 768,
+   "add_cross_attention": true
+ }
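
This JSON is the text-side (BERT-style) configuration consumed when the BLIP models are built. A minimal sketch of inspecting it, assuming the Hugging Face `transformers` package is available; the Space itself constructs the models through `modelsn.blip` / `modelsn.blip_vqa`, not through this snippet:

```python
from transformers import BertConfig

# Load the MED (multimodal encoder-decoder) config shipped with the Space.
cfg = BertConfig.from_json_file('configs/med_config.json')

# vocab_size is 30524 rather than BERT-base's 30522 because two extra special
# tokens are added; add_cross_attention enables the image-conditioned layers.
print(cfg.vocab_size, cfg.add_cross_attention, cfg.hidden_size)
```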
configs/nlvr.yaml ADDED
@@ -0,0 +1,21 @@
+ image_root: '/export/share/datasets/vision/NLVR2/'
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
+
+ #size of vit model; base or large
+ vit: 'base'
+ batch_size_train: 16
+ batch_size_test: 64
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ max_epoch: 15
+
+ image_size: 384
+
+ # optimizer
+ weight_decay: 0.05
+ init_lr: 3e-5
+ min_lr: 0
+
configs/nocaps.yaml ADDED
@@ -0,0 +1,15 @@
+ image_root: '/export/share/datasets/vision/nocaps/'
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+ vit: 'base'
+ batch_size: 32
+
+ image_size: 384
+
+ max_length: 20
+ min_length: 5
+ num_beams: 3
+ prompt: 'a picture of '
configs/pretrain.yaml ADDED
@@ -0,0 +1,27 @@
+ train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
+              '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
+             ]
+ laion_path: ''
+
+ # size of vit model; base or large
+ vit: 'base'
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+
+ image_size: 224
+ batch_size: 75
+
+ queue_size: 57600
+ alpha: 0.4
+
+ # optimizer
+ weight_decay: 0.05
+ init_lr: 3e-4
+ min_lr: 1e-6
+ warmup_lr: 1e-6
+ lr_decay_rate: 0.9
+ max_epoch: 20
+ warmup_steps: 3000
+
+
+
configs/retrieval_coco.yaml ADDED
@@ -0,0 +1,34 @@
+ image_root: '/export/share/datasets/vision/coco/images/'
+ ann_root: 'annotation'
+ dataset: 'coco'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+ # size of vit model; base or large
+
+ vit: 'base'
+ batch_size_train: 32
+ batch_size_test: 64
+ vit_grad_ckpt: True
+ vit_ckpt_layer: 4
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # batch_size_train: 16
+ # batch_size_test: 32
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 12
+ # init_lr: 5e-6
+
+ image_size: 384
+ queue_size: 57600
+ alpha: 0.4
+ k_test: 256
+ negative_all_rank: True
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 6
+
configs/retrieval_flickr.yaml ADDED
@@ -0,0 +1,34 @@
+ image_root: '/export/share/datasets/vision/flickr30k/'
+ ann_root: 'annotation'
+ dataset: 'flickr'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
+
+ # size of vit model; base or large
+
+ vit: 'base'
+ batch_size_train: 32
+ batch_size_test: 64
+ vit_grad_ckpt: True
+ vit_ckpt_layer: 4
+ init_lr: 1e-5
+
+ # vit: 'large'
+ # batch_size_train: 16
+ # batch_size_test: 32
+ # vit_grad_ckpt: True
+ # vit_ckpt_layer: 10
+ # init_lr: 5e-6
+
+ image_size: 384
+ queue_size: 57600
+ alpha: 0.4
+ k_test: 128
+ negative_all_rank: False
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 6
+
configs/vqa.yaml ADDED
@@ -0,0 +1,25 @@
+ vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
+ vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
+ train_files: ['vqa_train','vqa_val','vg_qa']
+ ann_root: 'annotation'
+
+ # set pretrained as a file path or an url
+ pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
+
+ # size of vit model; base or large
+ vit: 'base'
+ batch_size_train: 16
+ batch_size_test: 32
+ vit_grad_ckpt: False
+ vit_ckpt_layer: 0
+ init_lr: 2e-5
+
+ image_size: 480
+
+ k_test: 128
+ inference: 'rank'
+
+ # optimizer
+ weight_decay: 0.05
+ min_lr: 0
+ max_epoch: 10
data/__init__.py ADDED
@@ -0,0 +1,101 @@
+ import torch
+ from torch.utils.data import DataLoader
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+ from data.nocaps_dataset import nocaps_eval
+ from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+ from data.vqa_dataset import vqa_dataset
+ from data.nlvr_dataset import nlvr_dataset
+ from data.pretrain_dataset import pretrain_dataset
+ from transform.randaugment import RandomAugment
+
+ def create_dataset(dataset, config, min_scale=0.5):
+
+     normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+     transform_train = transforms.Compose([
+         transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
+         transforms.RandomHorizontalFlip(),
+         RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+                                            'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+         transforms.ToTensor(),
+         normalize,
+     ])
+     transform_test = transforms.Compose([
+         transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
+         transforms.ToTensor(),
+         normalize,
+     ])
+
+     if dataset=='pretrain':
+         dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+         return dataset
+
+     elif dataset=='caption_coco':
+         train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+         val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='nocaps':
+         val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return val_dataset, test_dataset
+
+     elif dataset=='retrieval_coco':
+         train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+         val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='retrieval_flickr':
+         train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+         val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+         test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+         return train_dataset, val_dataset, test_dataset
+
+     elif dataset=='vqa':
+         train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+                                     train_files = config['train_files'], split='train')
+         test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+         return train_dataset, test_dataset
+
+     elif dataset=='nlvr':
+         train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
+         val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
+         test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
+         return train_dataset, val_dataset, test_dataset
+
+
+ def create_sampler(datasets, shuffles, num_tasks, global_rank):
+     samplers = []
+     for dataset,shuffle in zip(datasets,shuffles):
+         sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+         samplers.append(sampler)
+     return samplers
+
+
+ def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+     loaders = []
+     for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
+         if is_train:
+             shuffle = (sampler is None)
+             drop_last = True
+         else:
+             shuffle = False
+             drop_last = False
+         loader = DataLoader(
+             dataset,
+             batch_size=bs,
+             num_workers=n_worker,
+             pin_memory=True,
+             sampler=sampler,
+             shuffle=shuffle,
+             collate_fn=collate_fn,
+             drop_last=drop_last,
+         )
+         loaders.append(loader)
+     return loaders
+
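
Taken together, `create_dataset`, `create_sampler`, and `create_loader` are meant to be driven by the YAML files under `configs/`. A minimal single-process sketch for the COCO captioning task, assuming the repository root is on `PYTHONPATH`, PyYAML is installed, and the paths inside `configs/caption_coco.yaml` have been adjusted to point at local data:

```python
import yaml

from data import create_dataset, create_loader

# Parse the task config; keys such as image_size, batch_size and prompt
# come straight from configs/caption_coco.yaml.
config = yaml.safe_load(open('configs/caption_coco.yaml', 'r'))

train_ds, val_ds, test_ds = create_dataset('caption_coco', config)

# Single-process run: pass None per dataset instead of a DistributedSampler;
# create_sampler would supply samplers under torch.distributed.
train_loader, val_loader, test_loader = create_loader(
    [train_ds, val_ds, test_ds],
    samplers=[None, None, None],
    batch_size=[config['batch_size']] * 3,
    num_workers=[4, 4, 4],
    is_trains=[True, False, False],
    collate_fns=[None, None, None],
)
```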
data/coco_karpathy_dataset.py ADDED
@@ -0,0 +1,126 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from data.utils import pre_caption
+
+ class coco_karpathy_train(Dataset):
+     def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         '''
+         url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
+         filename = 'coco_karpathy_train.json'
+
+         download_url(url,ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+         self.transform = transform
+         self.image_root = image_root
+         self.max_words = max_words
+         self.prompt = prompt
+
+         self.img_ids = {}
+         n = 0
+         for ann in self.annotation:
+             img_id = ann['image_id']
+             if img_id not in self.img_ids.keys():
+                 self.img_ids[img_id] = n
+                 n += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+         return image, caption, self.img_ids[ann['image_id']]
+
+
+ class coco_karpathy_caption_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+         filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]
+
+         return image, int(img_id)
+
+
+ class coco_karpathy_retrieval_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split, max_words=30):
+         '''
+         image_root (string): Root directory of images (e.g. coco/images/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
+         filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+         self.text = []
+         self.image = []
+         self.txt2img = {}
+         self.img2txt = {}
+
+         txt_id = 0
+         for img_id, ann in enumerate(self.annotation):
+             self.image.append(ann['image'])
+             self.img2txt[img_id] = []
+             for i, caption in enumerate(ann['caption']):
+                 self.text.append(pre_caption(caption,max_words))
+                 self.img2txt[img_id].append(txt_id)
+                 self.txt2img[txt_id] = img_id
+                 txt_id += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, index
data/flickr30k_dataset.py ADDED
@@ -0,0 +1,93 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from data.utils import pre_caption
+
+ class flickr30k_train(Dataset):
+     def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
+         '''
+         image_root (string): Root directory of images (e.g. flickr30k/)
+         ann_root (string): directory to store the annotation file
+         '''
+         url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
+         filename = 'flickr30k_train.json'
+
+         download_url(url,ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
+         self.transform = transform
+         self.image_root = image_root
+         self.max_words = max_words
+         self.prompt = prompt
+
+         self.img_ids = {}
+         n = 0
+         for ann in self.annotation:
+             img_id = ann['image_id']
+             if img_id not in self.img_ids.keys():
+                 self.img_ids[img_id] = n
+                 n += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         caption = self.prompt+pre_caption(ann['caption'], self.max_words)
+
+         return image, caption, self.img_ids[ann['image_id']]
+
+
+ class flickr30k_retrieval_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split, max_words=30):
+         '''
+         image_root (string): Root directory of images (e.g. flickr30k/)
+         ann_root (string): directory to store the annotation file
+         split (string): val or test
+         '''
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
+         filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+         self.text = []
+         self.image = []
+         self.txt2img = {}
+         self.img2txt = {}
+
+         txt_id = 0
+         for img_id, ann in enumerate(self.annotation):
+             self.image.append(ann['image'])
+             self.img2txt[img_id] = []
+             for i, caption in enumerate(ann['caption']):
+                 self.text.append(pre_caption(caption,max_words))
+                 self.img2txt[img_id].append(txt_id)
+                 self.txt2img[txt_id] = img_id
+                 txt_id += 1
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         image_path = os.path.join(self.image_root, self.annotation[index]['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, index
data/nlvr_dataset.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import json
+ import random
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ from data.utils import pre_caption
+
+ class nlvr_dataset(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         '''
+         image_root (string): Root directory of images
+         ann_root (string): directory to store the annotation file
+         split (string): train, val or test
+         '''
+         urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
+                 'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
+         filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}
+
+         download_url(urls[split],ann_root)
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+
+         self.transform = transform
+         self.image_root = image_root
+
+
+     def __len__(self):
+         return len(self.annotation)
+
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image0_path = os.path.join(self.image_root,ann['images'][0])
+         image0 = Image.open(image0_path).convert('RGB')
+         image0 = self.transform(image0)
+
+         image1_path = os.path.join(self.image_root,ann['images'][1])
+         image1 = Image.open(image1_path).convert('RGB')
+         image1 = self.transform(image1)
+
+         sentence = pre_caption(ann['sentence'], 40)
+
+         if ann['label']=='True':
+             label = 1
+         else:
+             label = 0
+
+         words = sentence.split(' ')
+
+         if 'left' not in words and 'right' not in words:
+             if random.random()<0.5:
+                 return image0, image1, sentence, label
+             else:
+                 return image1, image0, sentence, label
+         else:
+             if random.random()<0.5:
+                 return image0, image1, sentence, label
+             else:
+                 new_words = []
+                 for word in words:
+                     if word=='left':
+                         new_words.append('right')
+                     elif word=='right':
+                         new_words.append('left')
+                     else:
+                         new_words.append(word)
+
+                 sentence = ' '.join(new_words)
+                 return image1, image0, sentence, label
+
+
+
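
The left/right handling in `__getitem__` is the only non-obvious part of this dataset: when a sentence mentions `left` or `right` and the two images are swapped, those words are swapped as well so the label stays valid. A standalone sketch of that string logic (no data needed):

```python
# Mirrors the word-swap branch of nlvr_dataset.__getitem__.
def swap_left_right(sentence):
    flip = {'left': 'right', 'right': 'left'}
    return ' '.join(flip.get(word, word) for word in sentence.split(' '))

print(swap_left_right('the bottle on the left is empty'))
# the bottle on the right is empty
```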
data/nocaps_dataset.py ADDED
@@ -0,0 +1,32 @@
+ import os
+ import json
+
+ from torch.utils.data import Dataset
+ from torchvision.datasets.utils import download_url
+
+ from PIL import Image
+
+ class nocaps_eval(Dataset):
+     def __init__(self, transform, image_root, ann_root, split):
+         urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
+                 'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
+         filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}
+
+         download_url(urls[split],ann_root)
+
+         self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
+         self.transform = transform
+         self.image_root = image_root
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image_path = os.path.join(self.image_root,ann['image'])
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         return image, int(ann['img_id'])
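nocaps_eval yields (image, img_id) pairs, which is exactly what the captioning evaluation loop in eval_nocaps.py consumes. A hedged sketch of batching it (image root and size are placeholders):

```python
# Hedged sketch: batching the NoCaps validation split; paths and sizes are placeholders.
from torchvision import transforms
from torch.utils.data import DataLoader
from data.nocaps_dataset import nocaps_eval

transform = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
val_set = nocaps_eval(transform,
                      image_root='/path/to/nocaps/images',  # placeholder
                      ann_root='annotation',
                      split='val')

loader = DataLoader(val_set, batch_size=32, shuffle=False)
images, img_ids = next(iter(loader))  # (32, 3, 384, 384) tensor + int64 image ids
```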
data/pretrain_dataset.py ADDED
@@ -0,0 +1,59 @@
+ import json
+ import os
+ import random
+
+ from torch.utils.data import Dataset
+
+ from PIL import Image
+ from PIL import ImageFile
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
+ Image.MAX_IMAGE_PIXELS = None
+
+ from data.utils import pre_caption
+ import os,glob
+
+ class pretrain_dataset(Dataset):
+     def __init__(self, ann_file, laion_path, transform):
+
+         self.ann_pretrain = []
+         for f in ann_file:
+             print('loading '+f)
+             ann = json.load(open(f,'r'))
+             self.ann_pretrain += ann
+
+         self.laion_path = laion_path
+         if self.laion_path:
+             self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))
+
+             print('loading '+self.laion_files[0])
+             with open(self.laion_files[0],'r') as f:
+                 self.ann_laion = json.load(f)
+
+             self.annotation = self.ann_pretrain + self.ann_laion
+         else:
+             self.annotation = self.ann_pretrain
+
+         self.transform = transform
+
+
+     def reload_laion(self, epoch):
+         n = epoch%len(self.laion_files)
+         print('loading '+self.laion_files[n])
+         with open(self.laion_files[n],'r') as f:
+             self.ann_laion = json.load(f)
+
+         self.annotation = self.ann_pretrain + self.ann_laion
+
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         image = Image.open(ann['image']).convert('RGB')
+         image = self.transform(image)
+         caption = pre_caption(ann['caption'],30)
+
+         return image, caption
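pretrain_dataset keeps the human-annotated pairs fixed and appends one LAION shard at a time; calling reload_laion(epoch) swaps in shard epoch % num_shards. A hedged training-loop sketch (annotation paths, transform and batch size are placeholders, not values from this repo's configs):

```python
# Hedged sketch of cycling LAION shards across epochs; all paths and sizes
# below are placeholders, not values taken from this repo.
from torchvision import transforms
from torch.utils.data import DataLoader
from data.pretrain_dataset import pretrain_dataset

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = pretrain_dataset(ann_file=['annotation/pretrain_caption.json'],  # placeholder
                           laion_path='/path/to/laion/shards',             # placeholder
                           transform=transform)

for epoch in range(3):
    dataset.reload_laion(epoch)  # swap the LAION shard; dataset length may change
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
    for images, captions in loader:
        pass  # model forward/backward would go here
```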
data/utils.py ADDED
@@ -0,0 +1,112 @@
+ import re
+ import json
+ import os
+
+ import torch
+ import torch.distributed as dist
+
+ import utils
+
+ def pre_caption(caption,max_words=50):
+     caption = re.sub(
+         r"([.!\"()*#:;~])",
+         ' ',
+         caption.lower(),
+     )
+     caption = re.sub(
+         r"\s{2,}",
+         ' ',
+         caption,
+     )
+     caption = caption.rstrip('\n')
+     caption = caption.strip(' ')
+
+     #truncate caption
+     caption_words = caption.split(' ')
+     if len(caption_words)>max_words:
+         caption = ' '.join(caption_words[:max_words])
+
+     return caption
+
+ def pre_question(question,max_ques_words=50):
+     question = re.sub(
+         r"([.!\"()*#:;~])",
+         '',
+         question.lower(),
+     )
+     question = question.rstrip(' ')
+
+     #truncate question
+     question_words = question.split(' ')
+     if len(question_words)>max_ques_words:
+         question = ' '.join(question_words[:max_ques_words])
+
+     return question
+
+
+ def save_result(result, result_dir, filename, remove_duplicate=''):
+     result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
+     final_result_file = os.path.join(result_dir, '%s.json'%filename)
+
+     json.dump(result,open(result_file,'w'))
+
+     dist.barrier()
+
+     if utils.is_main_process():
+         # combine results from all processes
+         result = []
+
+         for rank in range(utils.get_world_size()):
+             result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
+             res = json.load(open(result_file,'r'))
+             result += res
+
+         if remove_duplicate:
+             result_new = []
+             id_list = []
+             for res in result:
+                 if res[remove_duplicate] not in id_list:
+                     id_list.append(res[remove_duplicate])
+                     result_new.append(res)
+             result = result_new
+
+         json.dump(result,open(final_result_file,'w'))
+         print('result file saved to %s'%final_result_file)
+
+     return final_result_file
+
+
+
+ from pycocotools.coco import COCO
+ from pycocoevalcap.eval import COCOEvalCap
+ from torchvision.datasets.utils import download_url
+
+ def coco_caption_eval(coco_gt_root, results_file, split):
+     urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
+             'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
+     filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}
+
+     download_url(urls[split],coco_gt_root)
+     annotation_file = os.path.join(coco_gt_root,filenames[split])
+
+     # create coco object and coco_result object
+     coco = COCO(annotation_file)
+     coco_result = coco.loadRes(results_file)
+
+     # create coco_eval object by taking coco and coco_result
+     coco_eval = COCOEvalCap(coco, coco_result)
+
+     # evaluate on a subset of images by setting
+     # coco_eval.params['image_id'] = coco_result.getImgIds()
+     # please remove this line when evaluating the full validation set
+     # coco_eval.params['image_id'] = coco_result.getImgIds()
+
+     # evaluate results
+     # SPICE will take a few minutes the first time, but speeds up due to caching
+     coco_eval.evaluate()
+
+     # print output evaluation scores
+     for metric, score in coco_eval.eval.items():
+         print(f'{metric}: {score:.3f}')
+
+     return coco_eval
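pre_caption lowercases, replaces the listed punctuation with spaces, collapses repeated whitespace and truncates to max_words; pre_question is similar but deletes the punctuation outright and leaves '?' untouched. A quick illustration:

```python
# Quick illustration of the text normalisation helpers above.
from data.utils import pre_caption, pre_question

print(pre_caption('A DOG;  running!! (fast)', max_words=3))
# -> 'a dog running'   (punctuation stripped, lowercased, truncated to 3 words)

print(pre_question('What color is the dog?'))
# -> 'what color is the dog?'   (note: '?' is not in the stripped character set)
```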
data/vqa_dataset.py ADDED
@@ -0,0 +1,88 @@
+ import os
+ import json
+ import random
+ from PIL import Image
+
+ import torch
+ from torch.utils.data import Dataset
+ from data.utils import pre_question
+
+ from torchvision.datasets.utils import download_url
+
+ class vqa_dataset(Dataset):
+     def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
+         self.split = split
+
+         self.transform = transform
+         self.vqa_root = vqa_root
+         self.vg_root = vg_root
+
+         if split=='train':
+             urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
+                     'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
+                     'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}
+
+             self.annotation = []
+             for f in train_files:
+                 download_url(urls[f],ann_root)
+                 self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
+         else:
+             download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
+             self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))
+
+             download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
+             self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))
+
+
+     def __len__(self):
+         return len(self.annotation)
+
+     def __getitem__(self, index):
+
+         ann = self.annotation[index]
+
+         if ann['dataset']=='vqa':
+             image_path = os.path.join(self.vqa_root,ann['image'])
+         elif ann['dataset']=='vg':
+             image_path = os.path.join(self.vg_root,ann['image'])
+
+         image = Image.open(image_path).convert('RGB')
+         image = self.transform(image)
+
+         if self.split == 'test':
+             question = pre_question(ann['question'])
+             question_id = ann['question_id']
+             return image, question, question_id
+
+
+         elif self.split=='train':
+
+             question = pre_question(ann['question'])
+
+             if ann['dataset']=='vqa':
+                 answer_weight = {}
+                 for answer in ann['answer']:
+                     if answer in answer_weight.keys():
+                         answer_weight[answer] += 1/len(ann['answer'])
+                     else:
+                         answer_weight[answer] = 1/len(ann['answer'])
+
+                 answers = list(answer_weight.keys())
+                 weights = list(answer_weight.values())
+
+             elif ann['dataset']=='vg':
+                 answers = [ann['answer']]
+                 weights = [0.2]
+
+             return image, question, answers, weights
+
+
+ def vqa_collate_fn(batch):
+     image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
+     for image, question, answer, weights in batch:
+         image_list.append(image)
+         question_list.append(question)
+         weight_list += weights
+         answer_list += answer
+         n.append(len(answer))
+     return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
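vqa_collate_fn flattens the variable-length answer lists of a batch into single answer/weight lists and records in n how many answers each question contributed. A hedged sketch of wiring it into a DataLoader (image roots and size are placeholders):

```python
# Hedged sketch: training batches with the custom VQA collate function.
# The image roots are placeholders; 'vqa_train' is one of the keys above.
from torchvision import transforms
from torch.utils.data import DataLoader
from data.vqa_dataset import vqa_dataset, vqa_collate_fn

transform = transforms.Compose([transforms.Resize((480, 480)), transforms.ToTensor()])
train_set = vqa_dataset(transform,
                        ann_root='annotation',
                        vqa_root='/path/to/coco/images',  # placeholder
                        vg_root='/path/to/vg/images',     # placeholder
                        train_files=['vqa_train'],
                        split='train')

loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=vqa_collate_fn)
images, questions, answers, weights, n = next(iter(loader))
# answers/weights are flat; n[i] answers belong to question i,
# so sum(n) == len(answers) == len(weights).
```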
elephant.jpg ADDED
eval_nocaps.py ADDED
@@ -0,0 +1,118 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from models.blip import blip_decoder
+ import utils
+ from data import create_dataset, create_sampler, create_loader
+ from data.utils import save_result
+
+ @torch.no_grad()
+ def evaluate(model, data_loader, device, config):
+     # evaluate
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter=" ")
+     header = 'Evaluation:'
+     print_freq = 10
+
+     result = []
+     for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+         image = image.to(device)
+
+         captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+                                   min_length=config['min_length'], repetition_penalty=1.1)
+
+         for caption, img_id in zip(captions, image_id):
+             result.append({"image_id": img_id.item(), "caption": caption})
+
+     return result
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating captioning dataset")
+     val_dataset, test_dataset = create_dataset('nocaps', config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
+     else:
+         samplers = [None,None]
+
+     val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
+                                             batch_size=[config['batch_size']]*2,num_workers=[4,4],
+                                             is_trains=[False, False], collate_fns=[None,None])
+
+     #### Model ####
+     print("Creating model")
+     model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+                          prompt=config['prompt'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     val_result = evaluate(model_without_ddp, val_loader, device, config)
+     val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
+     test_result = evaluate(model_without_ddp, test_loader, device, config)
+     test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/nocaps.yaml')
+     parser.add_argument('--output_dir', default='output/NoCaps')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     args.result_dir = os.path.join(args.output_dir, 'result')
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
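evaluate() only needs a captioning model and a loader that yields (image, img_id) batches, so it can also be driven on a single GPU without the distributed entry point. A hedged sketch reusing the same config keys read above (the checkpoint, image size and decoding values come from configs/nocaps.yaml and are not reproduced here):

```python
# Hedged single-GPU sketch of calling evaluate() directly; config values are
# read from configs/nocaps.yaml and are not reproduced here.
import torch
import ruamel_yaml as yaml
from models.blip import blip_decoder
from data import create_dataset, create_loader
from eval_nocaps import evaluate

config = yaml.load(open('./configs/nocaps.yaml', 'r'), Loader=yaml.Loader)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

val_dataset, test_dataset = create_dataset('nocaps', config)
val_loader, test_loader = create_loader([val_dataset, test_dataset], [None, None],
                                        batch_size=[config['batch_size']] * 2,
                                        num_workers=[4, 4], is_trains=[False, False],
                                        collate_fns=[None, None])

model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], prompt=config['prompt']).to(device)

val_result = evaluate(model, val_loader, device, config)
print(val_result[:2])  # [{'image_id': ..., 'caption': ...}, ...]
```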
examples/ex1.jpg ADDED
examples/ex2.jpg ADDED
examples/ex3.jpg ADDED
extras/.DS_Store ADDED
Binary file (6.15 kB). View file
 
extras/sample-images/0.JPG ADDED
extras/sample-images/1.JPG ADDED
extras/sample-images/10.jpg ADDED
extras/sample-images/2.jpg ADDED
extras/sample-images/3.jpg ADDED
extras/sample-images/4.jpg ADDED
extras/sample-images/5.jpg ADDED
extras/sample-images/6.JPG ADDED
extras/sample-images/7.JPG ADDED
extras/sample-images/8.jpg ADDED
extras/sample-images/9.jpg ADDED
foo.png ADDED
gradio_cached_examples/log.csv ADDED
@@ -0,0 +1,2 @@
+ Output
+ caption: a painting of a starry night over a city
local_run.ipynb ADDED
@@ -0,0 +1,347 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on local URL: http://127.0.0.1:7860/\n",
+ "\n",
+ "To create a public link, set `share=True` in `launch()`.\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " <iframe\n",
+ " width=\"900\"\n",
+ " height=\"500\"\n",
+ " src=\"http://127.0.0.1:7860/\"\n",
+ " frameborder=\"0\"\n",
+ " allowfullscreen\n",
+ " \n",
+ " ></iframe>\n",
+ " "
+ ],
+ "text/plain": [
+ "<IPython.lib.display.IFrame at 0x7fbca787f520>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": [
+ "(<fastapi.applications.FastAPI at 0x7fbcc67ceeb0>,\n",
+ " 'http://127.0.0.1:7860/',\n",
+ " None)"
+ ]
+ },
+ "execution_count": 1,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2022-02-09 14:10:22.417549: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Number of Helmets: 4\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 5\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 5, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Number of Helmets: 4\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "WARNING:tensorflow:5 out of the last 5 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fbc729998b0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 5\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 5, 'WHV': 0, 'WV': 0}\n",
+ "WARNING:tensorflow:6 out of the last 6 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7fbc979e9ee0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "Workers wearing helmet and vest: 3\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 0\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "Number of Helmets: 3\n",
+ "Number of Vests: 1\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Number of Helmets: 4\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 6\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 4\n",
+ "Workers not wearing helmet and vest: 2\n",
+ "\n",
+ "\n",
+ "dict vals:\n",
+ "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 6\n",
+ "dict vals:\n",
+ "{'W': 6, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Number of Helmets: 4\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 6\n",
+ "dict vals:\n",
+ "{'W': 6, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 6\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 4\n",
+ "Workers not wearing helmet and vest: 2\n",
+ "\n",
+ "\n",
+ "dict vals:\n",
+ "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 1\n",
+ "Number of Helmets: 1\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 1, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 1\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 1\n",
+ "dict vals:\n",
+ "{'W': 1, 'WH': 1, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 1\n",
+ "dict vals:\n",
+ "{'W': 1, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 1\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 1\n",
+ "dict vals:\n",
+ "{'W': 1, 'WH': 1, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 5\n",
+ "Number of Helmets: 4\n",
+ "Number of Vests: 0\n",
+ "dict vals:\n",
+ "{'W': 5, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 6\n",
+ "Workers wearing helmet and vest: 0\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 4\n",
+ "Workers not wearing helmet and vest: 2\n",
+ "\n",
+ "\n",
+ "dict vals:\n",
+ "{'W': 6, 'WH': 4, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "Workers wearing helmet and vest: 3\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 0\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "Number of Helmets: 3\n",
+ "Number of Vests: 1\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 0, 'WV': 0}\n",
+ "\n",
+ "\n",
+ "\n",
+ "Total workers: 3\n",
+ "Workers wearing helmet and vest: 3\n",
+ "Workers wearing only vest: 0\n",
+ "Workers wearing only helmet: 0\n",
+ "dict vals:\n",
+ "{'W': 3, 'WH': 0, 'WHV': 3, 'WV': 0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "import numpy as np\n",
+ "import run_code\n",
+ "import cv2\n",
+ "import gradio as gr\n",
+ "\n",
+ "\n",
+ "def sepia(Input_Image, Approach):\n",
+ " pil_image = Input_Image\n",
+ " open_cv_image = np.asarray(pil_image)\n",
+ " # Convert RGB to BGR\n",
+ " #open_cv_image = open_cv_image[:, :, ::-1].copy()\n",
+ " #Approach = 3\n",
+ " sepia_img = run_code.run(open_cv_image, Approach)\n",
+ " images = sepia_img['img']\n",
+ " texts= sepia_img['text']\n",
+ " #print (labels)\n",
+ " return images, texts\n",
+ "\n",
+ "image = [gr.inputs.Image(type=\"pil\"), gr.inputs.Radio([1, 2, 3])]\n",
+ "#output = [\"image\", gr.outputs.Label(num_top_classes=4)]\n",
+ "output = [\"image\", gr.outputs.Textbox(type=\"auto\")]\n",
+ "#output = gr.outputs.Label(num_top_classes=4)\n",
+ "\n",
+ "title=\"Real-time Detection of Personal-Protective-Equipment (PPE)\"\n",
+ "description=\"This demo is the implementation of Real-time Detection of Personal-Protective-Equipment (PPE) paper https://github.com/ciber-lab/pictor-ppe\" \\\n",
+ " \" - by Sanjay Kamath \"\n",
+ "examples = [[\"examples/ex1.jpg\", 1], [\"examples/ex2.jpg\", 2], [\"examples/ex3.jpg\", 3]]\n",
+ "\n",
+ "#iface = gr.Interface(sepia , [ gr.inputs.Image(shape=(200, 200)), gr.inputs.Radio([1, 2, 3])], \"image\", title=title,\n",
+ "# examples = [[\"examples/ex1.jpg\"], [\"examples/ex2.jpg\"], [\"examples/ex3.jpg\"]],\n",
+ "# description=description)\n",
+ "\n",
+ "iface = gr.Interface(fn=sepia, inputs=image, outputs=output, title=title, description=description, examples=examples)\n",
+ "\n",
+ "iface.launch()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
model-data/.DS_Store ADDED
Binary file (6.15 kB). View file
 
model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ec800aa5acdd9719ff5e63b34d1374e5c8a31e17f38f3a8250bf1aeeac1a972
+ size 246910096
model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:317831ba378b8ec02e24e57859876eb0348284c8a75155143c9df85ee478c47b
+ size 246931600
model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d06d4956d0f6b3ac71f02e103e9efdc4b222ce83aeae232f65ee6c04ee1dd2d7
+ size 246867088
model-data/weights/readme.md ADDED
@@ -0,0 +1 @@
+ Download the trained weights of YOLO models ([Google Drive folder](https://drive.google.com/drive/folders/13tCdROHnS0c5VibW1VO8pOEj0rXEvvGj?usp=sharing)) and put in this folder.
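If a scripted download is preferred over the manual step above, a hedged sketch using the third-party gdown package (not a dependency of this repo) against the same Google Drive folder:

```python
# Hedged sketch: gdown is a third-party package (pip install gdown), not part
# of this repo; the output directory is assumed to be model-data/weights/.
import gdown

gdown.download_folder(
    url='https://drive.google.com/drive/folders/13tCdROHnS0c5VibW1VO8pOEj0rXEvvGj?usp=sharing',
    output='model-data/weights',
    quiet=False,
)
```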
modelsn/__init__.py ADDED
File without changes