Spaces: Runtime error

rsanjaykamath committed · Commit 7fc7f3d · 1 Parent(s): eb43f71

push

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .DS_Store +0 -0
- .idea/.gitignore +3 -0
- LICENSE.txt +12 -0
- README 2.md +46 -0
- README.md +40 -6
- __pycache__/run_code.cpython-38.pyc +0 -0
- app.py +232 -0
- app_run.ipynb +400 -0
- configs/caption_coco.yaml +33 -0
- configs/med_config.json +21 -0
- configs/nlvr.yaml +21 -0
- configs/nocaps.yaml +15 -0
- configs/pretrain.yaml +27 -0
- configs/retrieval_coco.yaml +34 -0
- configs/retrieval_flickr.yaml +34 -0
- configs/vqa.yaml +25 -0
- data/__init__.py +101 -0
- data/coco_karpathy_dataset.py +126 -0
- data/flickr30k_dataset.py +93 -0
- data/nlvr_dataset.py +78 -0
- data/nocaps_dataset.py +32 -0
- data/pretrain_dataset.py +59 -0
- data/utils.py +112 -0
- data/vqa_dataset.py +88 -0
- elephant.jpg +0 -0
- eval_nocaps.py +118 -0
- examples/ex1.jpg +0 -0
- examples/ex2.jpg +0 -0
- examples/ex3.jpg +0 -0
- extras/.DS_Store +0 -0
- extras/sample-images/0.JPG +0 -0
- extras/sample-images/1.JPG +0 -0
- extras/sample-images/10.jpg +0 -0
- extras/sample-images/2.jpg +0 -0
- extras/sample-images/3.jpg +0 -0
- extras/sample-images/4.jpg +0 -0
- extras/sample-images/5.jpg +0 -0
- extras/sample-images/6.JPG +0 -0
- extras/sample-images/7.JPG +0 -0
- extras/sample-images/8.jpg +0 -0
- extras/sample-images/9.jpg +0 -0
- foo.png +0 -0
- gradio_cached_examples/log.csv +2 -0
- local_run.ipynb +347 -0
- model-data/.DS_Store +0 -0
- model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5 +3 -0
- model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5 +3 -0
- model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5 +3 -0
- model-data/weights/readme.md +1 -0
- modelsn/__init__.py +0 -0
.DS_Store
ADDED
Binary file (10.2 kB)
.idea/.gitignore
ADDED
@@ -0,0 +1,3 @@
+
+# Default ignored files
+/workspace.xml
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
+Copyright (c) 2022, Salesforce.com, Inc.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README 2.md
ADDED
@@ -0,0 +1,46 @@
+---
+title: PPE_Detection
+emoji: 💩
+colorFrom: pink
+colorTo: indigo
+sdk: gradio
+app_file: app.py
+pinned: false
+license: other
+---
+
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio`, `streamlit`, or `static`
+
+`sdk_version` : _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+Path is relative to the root of the repository.
+
+`models`: _List[string]_
+HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
+Will be parsed automatically from your code if not specified here.
+
+`datasets`: _List[string]_
+HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
+Will be parsed automatically from your code if not specified here.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
README.md
CHANGED
@@ -1,12 +1,46 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: BLIP
+emoji: 🦀
+colorFrom: red
+colorTo: blue
 sdk: gradio
 app_file: app.py
 pinned: false
-license:
+license: bsd-3-clause
 ---
 
-
+# Configuration
+
+`title`: _string_
+Display title for the Space
+
+`emoji`: _string_
+Space emoji (emoji-only character allowed)
+
+`colorFrom`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`colorTo`: _string_
+Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+`sdk`: _string_
+Can be either `gradio`, `streamlit`, or `static`
+
+`sdk_version` : _string_
+Only applicable for `streamlit` SDK.
+See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+`app_file`: _string_
+Path to your main application file (which contains either `gradio` or `streamlit` Python code, or `static` html code).
+Path is relative to the root of the repository.
+
+`models`: _List[string]_
+HF model IDs (like "gpt2" or "deepset/roberta-base-squad2") used in the Space.
+Will be parsed automatically from your code if not specified here.
+
+`datasets`: _List[string]_
+HF dataset IDs (like "common_voice" or "oscar-corpus/OSCAR-2109") used in the Space.
+Will be parsed automatically from your code if not specified here.
+
+`pinned`: _boolean_
+Whether the Space stays on top of your list.
__pycache__/run_code.cpython-38.pyc
ADDED
Binary file (3.18 kB)
app.py
ADDED
@@ -0,0 +1,232 @@
+import os
+
+os.system(
+    "wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg")
+
+from PIL import Image
+import requests
+import torch
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+# MDETR Code
+import torchvision.transforms as T
+import matplotlib.pyplot as plt
+from collections import defaultdict
+import torch.nn.functional as F
+import numpy as np
+from skimage.measure import find_contours
+
+from matplotlib import patches, lines
+from matplotlib.patches import Polygon
+import gradio as gr
+
+torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg',
+                               'elephant.jpg')
+
+model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True,
+                                       return_postprocessor=True)
+model2 = model2.cpu()
+model2.eval()
+
+torch.set_grad_enabled(False);
+# standard PyTorch mean-std input image normalization
+transform = T.Compose([
+    T.Resize(800),
+    T.ToTensor(),
+    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+])
+
+
+# for output bounding box post-processing
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=1)
+
+
+def rescale_bboxes(out_bbox, size):
+    img_w, img_h = size
+    b = box_cxcywh_to_xyxy(out_bbox)
+    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+    return b
+
+
+# colors for visualization
+COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
+          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]
+
+
+def apply_mask(image, mask, color, alpha=0.5):
+    """Apply the given mask to the image.
+    """
+    for c in range(3):
+        image[:, :, c] = np.where(mask == 1,
+                                  image[:, :, c] *
+                                  (1 - alpha) + alpha * color[c] * 255,
+                                  image[:, :, c])
+    return image
+
+
+def plot_results(pil_img, scores, boxes, labels, masks=None):
+    plt.figure(figsize=(16, 10))
+    np_image = np.array(pil_img)
+    ax = plt.gca()
+    colors = COLORS * 100
+    if masks is None:
+        masks = [None for _ in range(len(scores))]
+    assert len(scores) == len(boxes) == len(labels) == len(masks)
+    for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):
+        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
+                                   fill=False, color=c, linewidth=3))
+        text = f'{l}: {s:0.2f}'
+        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))
+
+        if mask is None:
+            continue
+        np_image = apply_mask(np_image, mask, c)
+
+        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)
+        padded_mask[1:-1, 1:-1] = mask
+        contours = find_contours(padded_mask, 0.5)
+        for verts in contours:
+            # Subtract the padding and flip (y, x) to (x, y)
+            verts = np.fliplr(verts) - 1
+            p = Polygon(verts, facecolor="none", edgecolor=c)
+            ax.add_patch(p)
+
+    plt.imshow(np_image)
+    plt.axis('off')
+    plt.savefig('foo.png', bbox_inches='tight')
+    return 'foo.png'
+
+
+def add_res(results, ax, color='green'):
+    # for tt in results.values():
+    if True:
+        bboxes = results['boxes']
+        labels = results['labels']
+        scores = results['scores']
+    # keep = scores >= 0.0
+    # bboxes = bboxes[keep].tolist()
+    # labels = labels[keep].tolist()
+    # scores = scores[keep].tolist()
+    # print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]])))
+
+    colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']
+
+    for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):
+        ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))
+        cls_name = ll if isinstance(ll, str) else CLASSES[ll]
+        text = f'{cls_name}: {ss:.2f}'
+        print(text)
+        ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))
+
+
+def plot_inference(im, caption, approaches):
+    choices = {"Worker Helmet Separately": 1, "Worker Helmet Vest": 2, "Workers only": 3}
+
+    # mean-std normalize the input image (batch-size: 1)
+    img = transform(im).unsqueeze(0).cpu()
+
+    # propagate through the model
+    memory_cache = model2(img, [caption], encode_and_save=True)
+    outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache)
+
+    # keep only predictions with 0.7+ confidence
+    probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()
+    keep = (probas > 0.7).cpu()
+
+    # convert boxes from [0; 1] to image scales
+    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)
+
+    # Extract the text spans predicted by each box
+    positive_tokens = (outputs["pred_logits"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()
+    predicted_spans = defaultdict(str)
+    for tok in positive_tokens:
+        item, pos = tok
+        if pos < 255:
+            span = memory_cache["tokenized"].token_to_chars(0, pos)
+            predicted_spans[item] += " " + caption[span.start:span.end]
+
+    labels = [predicted_spans[k] for k in sorted(list(predicted_spans.keys()))]
+    caption = 'Caption: ' + caption
+    return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches]))
+
+
+# BLIP Code
+
+
+from modelsn.blip import blip_decoder
+
+image_size = 384
+transform = transforms.Compose([
+    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+])
+
+model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+model = blip_decoder(pretrained=model_url, image_size=384, vit='base')
+model.eval()
+model = model.to(device)
+
+from modelsn.blip_vqa import blip_vqa
+
+image_size_vq = 480
+transform_vq = transforms.Compose([
+    transforms.Resize((image_size_vq, image_size_vq), interpolation=InterpolationMode.BICUBIC),
+    transforms.ToTensor(),
+    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+])
+
+model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
+
+model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')
+model_vq.eval()
+model_vq = model_vq.to(device)
+
+
+def inference(raw_image, approaches, question):
+    image = transform(raw_image).unsqueeze(0).to(device)
+    with torch.no_grad():
+        caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)
+
+    return (plot_inference(raw_image, caption[0], approaches))
+    # return 'caption: '+caption[0]
+
+
+# PPE Detection code
+import numpy as np
+import run_code
+import gradio as gr
+
+
+def sepia_call(caption, Input_Image, MDETR_im, Approach):
+    pil_image = Input_Image
+    open_cv_image = np.asarray(pil_image)
+    sepia_img = run_code.run(open_cv_image, Approach)
+    images = sepia_img['img']
+    texts = sepia_img['text']
+
+    return (caption, MDETR_im, images, texts)
+
+
+inputs = [gr.inputs.Image(type='pil'),
+          gr.inputs.Radio(choices=["Worker Helmet Separately", "Worker Helmet Vest", "Workers only"], type="value",
+                          default="Worker Helmet Vest", label="Model"), "textbox"]
+outputs = [gr.outputs.Textbox(label="Output"), "image", "image", gr.outputs.Textbox(label="Output")]
+
+title = "BLIP + MDETR + PPE Detection"
+
+description = "Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>"
+
+gr.Interface(inference, inputs, outputs, title=title, description=description, article=article,
+             examples=[['starry.jpg', "Image Captioning", "None"]]).launch(share=True, enable_queue=True,
+                                                                           cache_examples=False)
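A note on the two box helpers added above (box_cxcywh_to_xyxy and rescale_bboxes): MDETR predicts boxes as normalized (center-x, center-y, width, height), and these helpers turn them into absolute (xmin, ymin, xmax, ymax) pixel corners before plotting. Below is a minimal, self-contained sketch of the same arithmetic, using NumPy instead of torch and made-up sample values; the function name is illustrative only.

import numpy as np

def cxcywh_to_xyxy_pixels(box, img_w, img_h):
    # box = (cx, cy, w, h), all normalized to [0, 1]
    cx, cy, w, h = box
    corners = np.array([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h])
    # scale x-coordinates by image width and y-coordinates by image height
    return corners * np.array([img_w, img_h, img_w, img_h])

# A box centred in a 640x480 image and covering half of each dimension:
print(cxcywh_to_xyxy_pixels((0.5, 0.5, 0.5, 0.5), 640, 480))  # [160. 120. 480. 360.]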
app_run.ipynb
ADDED
@@ -0,0 +1,400 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "15468c81",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "--2022-02-15 18:26:17-- https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\n",
+      "Resolving upload.wikimedia.org (upload.wikimedia.org)... 91.198.174.208\n",
+      "Connecting to upload.wikimedia.org (upload.wikimedia.org)|91.198.174.208|:443... connected.\n",
+      "HTTP request sent, awaiting response... 200 OK\n",
+      "Length: 1388211 (1.3M) [image/jpeg]\n",
+      "Saving to: ‘starry.jpg’\n",
+      "\n",
+      " 0K .......... .......... .......... .......... .......... 3% 776K 2s\n",
+      " 50K .......... .......... .......... .......... .......... 7% 877K 2s\n",
+      " 100K .......... .......... .......... .......... .......... 11% 2.93M 1s\n",
+      " 150K .......... .......... .......... .......... .......... 14% 2.28M 1s\n",
+      " 200K .......... .......... .......... .......... .......... 18% 4.04M 1s\n",
+      " 250K .......... .......... .......... .......... .......... 22% 5.46M 1s\n",
+      " 300K .......... .......... .......... .......... .......... 25% 6.40M 1s\n",
+      " 350K .......... .......... .......... .......... .......... 29% 2.41M 0s\n",
+      " 400K .......... .......... .......... .......... .......... 33% 3.18M 0s\n",
+      " 450K .......... .......... .......... .......... .......... 36% 3.03M 0s\n",
+      " 500K .......... .......... .......... .......... .......... 40% 8.30M 0s\n",
+      " 550K .......... .......... .......... .......... .......... 44% 3.31M 0s\n",
+      " 600K .......... .......... .......... .......... .......... 47% 3.10M 0s\n",
+      " 650K .......... .......... .......... .......... .......... 51% 12.3M 0s\n",
+      " 700K .......... .......... .......... .......... .......... 55% 4.20M 0s\n",
+      " 750K .......... .......... .......... .......... .......... 59% 1.93M 0s\n",
+      " 800K .......... .......... .......... .......... .......... 62% 6.28M 0s\n",
+      " 850K .......... .......... .......... .......... .......... 66% 3.09M 0s\n",
+      " 900K .......... .......... .......... .......... .......... 70% 22.7M 0s\n",
+      " 950K .......... .......... .......... .......... .......... 73% 4.43M 0s\n",
+      " 1000K .......... .......... .......... .......... .......... 77% 4.16M 0s\n",
+      " 1050K .......... .......... .......... .......... .......... 81% 2.29M 0s\n",
+      " 1100K .......... .......... .......... .......... .......... 84% 1.81M 0s\n",
+      " 1150K .......... .......... .......... .......... .......... 88% 6.20M 0s\n",
+      " 1200K .......... .......... .......... .......... .......... 92% 2.03M 0s\n",
+      " 1250K .......... .......... .......... .......... .......... 95% 23.5M 0s\n",
+      " 1300K .......... .......... .......... .......... .......... 99% 5.04M 0s\n",
+      " 1350K ..... 100% 9.95M=0.5s\n",
+      "\n",
+      "2022-02-15 18:26:17 (2.89 MB/s) - ‘starry.jpg’ saved [1388211/1388211]\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "02b7655f0b2b404b952b7c152a3a1661",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       " 0%| | 0.00/262k [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Using cache found in /Users/sanjaykamath/.cache/torch/hub/ashkamath_mdetr_main\n",
+      "Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.bias', 'lm_head.dense.weight']\n",
+      "- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+      "- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
+      "load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth\n",
+      "Running on local URL: http://127.0.0.1:7862/\n",
+      "Running on public URL: https://13389.gradio.app\n",
+      "\n",
+      "This share link expires in 72 hours. For free permanent hosting, check out Spaces (https://huggingface.co/spaces)\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       "        <iframe\n",
+       "            width=\"900\"\n",
+       "            height=\"500\"\n",
+       "            src=\"https://13389.gradio.app\"\n",
+       "            frameborder=\"0\"\n",
+       "            allowfullscreen\n",
+       "            \n",
+       "        ></iframe>\n",
+       "        "
+      ],
+      "text/plain": [
+       "<IPython.lib.display.IFrame at 0x7fce90855f40>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "(<fastapi.applications.FastAPI at 0x7fcfa3376fd0>,\n",
+       " 'http://127.0.0.1:7862/',\n",
+       " 'https://13389.gradio.app')"
+      ]
+     },
+     "execution_count": 3,
+     "metadata": {},
+     "output_type": "execute_result"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "2022-02-15 18:27:19.011924: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
+      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
+     ]
+    }
+   ],
+   "source": [
+    "import os\n",
+    "os.system(\"wget https://upload.wikimedia.org/wikipedia/commons/thumb/e/ea/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg/1920px-Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg -O starry.jpg\")\n",
+    "\n",
+    "from PIL import Image\n",
+    "import requests\n",
+    "import torch\n",
+    "from torchvision import transforms\n",
+    "from torchvision.transforms.functional import InterpolationMode\n",
+    "\n",
+    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+    "\n",
+    "\n",
+    "\n",
+    " \n",
+    "#MDETR Code \n",
+    "import torchvision.transforms as T\n",
+    "import matplotlib.pyplot as plt\n",
+    "from collections import defaultdict\n",
+    "import torch.nn.functional as F\n",
+    "import numpy as np\n",
+    "from skimage.measure import find_contours\n",
+    "\n",
+    "from matplotlib import patches, lines\n",
+    "from matplotlib.patches import Polygon\n",
+    "import gradio as gr\n",
+    "\n",
+    "torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2014/03/04/15/10/elephants-279505_1280.jpg', 'elephant.jpg')\n",
+    "\n",
+    "\n",
+    "model2, postprocessor = torch.hub.load('ashkamath/mdetr:main', 'mdetr_efficientnetB5', pretrained=True, return_postprocessor=True)\n",
+    "model2 = model2.cpu()\n",
+    "model2.eval()\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "torch.set_grad_enabled(False);\n",
+    "# standard PyTorch mean-std input image normalization\n",
+    "transform = T.Compose([\n",
+    "    T.Resize(800),\n",
+    "    T.ToTensor(),\n",
+    "    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
+    "])\n",
+    "\n",
+    "# for output bounding box post-processing\n",
+    "def box_cxcywh_to_xyxy(x):\n",
+    "    x_c, y_c, w, h = x.unbind(1)\n",
+    "    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),\n",
+    "         (x_c + 0.5 * w), (y_c + 0.5 * h)]\n",
+    "    return torch.stack(b, dim=1)\n",
+    "\n",
+    "def rescale_bboxes(out_bbox, size):\n",
+    "    img_w, img_h = size\n",
+    "    b = box_cxcywh_to_xyxy(out_bbox)\n",
+    "    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)\n",
+    "    return b\n",
+    "# colors for visualization\n",
+    "COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],\n",
+    "          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]\n",
+    "\n",
+    "def apply_mask(image, mask, color, alpha=0.5):\n",
+    "    \"\"\"Apply the given mask to the image.\n",
+    "    \"\"\"\n",
+    "    for c in range(3):\n",
+    "        image[:, :, c] = np.where(mask == 1,\n",
+    "                                  image[:, :, c] *\n",
+    "                                  (1 - alpha) + alpha * color[c] * 255,\n",
+    "                                  image[:, :, c])\n",
+    "    return image\n",
+    "\n",
+    "def plot_results(pil_img, scores, boxes, labels, masks=None):\n",
+    "    plt.figure(figsize=(16,10))\n",
+    "    np_image = np.array(pil_img)\n",
+    "    ax = plt.gca()\n",
+    "    colors = COLORS * 100\n",
+    "    if masks is None:\n",
+    "        masks = [None for _ in range(len(scores))]\n",
+    "    assert len(scores) == len(boxes) == len(labels) == len(masks)\n",
+    "    for s, (xmin, ymin, xmax, ymax), l, mask, c in zip(scores, boxes.tolist(), labels, masks, colors):\n",
+    "        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,\n",
+    "                                   fill=False, color=c, linewidth=3))\n",
+    "        text = f'{l}: {s:0.2f}'\n",
+    "        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n",
+    "\n",
+    "        if mask is None:\n",
+    "            continue\n",
+    "        np_image = apply_mask(np_image, mask, c)\n",
+    "\n",
+    "        padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2), dtype=np.uint8)\n",
+    "        padded_mask[1:-1, 1:-1] = mask\n",
+    "        contours = find_contours(padded_mask, 0.5)\n",
+    "        for verts in contours:\n",
+    "            # Subtract the padding and flip (y, x) to (x, y)\n",
+    "            verts = np.fliplr(verts) - 1\n",
+    "            p = Polygon(verts, facecolor=\"none\", edgecolor=c)\n",
+    "            ax.add_patch(p)\n",
+    "\n",
+    "\n",
+    "    plt.imshow(np_image)\n",
+    "    plt.axis('off')\n",
+    "    plt.savefig('foo.png',bbox_inches='tight')\n",
+    "    return 'foo.png'\n",
+    "\n",
+    "\n",
+    "def add_res(results, ax, color='green'):\n",
+    "    #for tt in results.values():\n",
+    "    if True:\n",
+    "        bboxes = results['boxes']\n",
+    "        labels = results['labels']\n",
+    "        scores = results['scores']\n",
+    "    #keep = scores >= 0.0\n",
+    "    #bboxes = bboxes[keep].tolist()\n",
+    "    #labels = labels[keep].tolist()\n",
+    "    #scores = scores[keep].tolist()\n",
+    "    #print(torchvision.ops.box_iou(tt['boxes'].cpu().detach(), torch.as_tensor([[xmin, ymin, xmax, ymax]])))\n",
+    "    \n",
+    "    colors = ['purple', 'yellow', 'red', 'green', 'orange', 'pink']\n",
+    "    \n",
+    "    for i, (b, ll, ss) in enumerate(zip(bboxes, labels, scores)):\n",
+    "        ax.add_patch(plt.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], fill=False, color=colors[i], linewidth=3))\n",
+    "        cls_name = ll if isinstance(ll,str) else CLASSES[ll]\n",
+    "        text = f'{cls_name}: {ss:.2f}'\n",
+    "        print(text)\n",
+    "        ax.text(b[0], b[1], text, fontsize=15, bbox=dict(facecolor='white', alpha=0.8))\n",
+    "\n",
+    "\n",
+    "def plot_inference(im, caption, approaches):\n",
+    "    \n",
+    "    choices = {\"Worker Helmet Separately\" : 1,\"Worker Helmet Vest\":2, \"Workers only\":3}\n",
+    "    \n",
+    "    \n",
+    "# mean-std normalize the input image (batch-size: 1)\n",
+    "    img = transform(im).unsqueeze(0).cpu()\n",
+    "\n",
+    "    # propagate through the model\n",
+    "    memory_cache = model2(img, [caption], encode_and_save=True)\n",
+    "    outputs = model2(img, [caption], encode_and_save=False, memory_cache=memory_cache)\n",
+    "\n",
+    "    # keep only predictions with 0.7+ confidence\n",
+    "    probas = 1 - outputs['pred_logits'].softmax(-1)[0, :, -1].cpu()\n",
+    "    keep = (probas > 0.7).cpu()\n",
+    "\n",
+    "    # convert boxes from [0; 1] to image scales\n",
+    "    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'].cpu()[0, keep], im.size)\n",
+    "\n",
+    "    # Extract the text spans predicted by each box\n",
+    "    positive_tokens = (outputs[\"pred_logits\"].cpu()[0, keep].softmax(-1) > 0.1).nonzero().tolist()\n",
+    "    predicted_spans = defaultdict(str)\n",
+    "    for tok in positive_tokens:\n",
+    "        item, pos = tok\n",
+    "        if pos < 255:\n",
+    "            span = memory_cache[\"tokenized\"].token_to_chars(0, pos)\n",
+    "            predicted_spans [item] += \" \" + caption[span.start:span.end]\n",
+    "\n",
+    "    labels = [predicted_spans [k] for k in sorted(list(predicted_spans .keys()))]\n",
+    "    caption = 'Caption: '+ caption\n",
+    "    return (sepia_call(caption, im, plot_results(im, probas[keep], bboxes_scaled, labels), choices[approaches]))\n",
+    "    \n",
+    "\n",
+    "\n",
+    "    \n",
+    "#BLIP Code\n",
+    "\n",
+    "\n",
+    "from modelsn.blip import blip_decoder\n",
+    "\n",
+    "image_size = 384\n",
+    "transform = transforms.Compose([\n",
+    "    transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+    "    ]) \n",
+    "\n",
+    "model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
+    "    \n",
+    "model = blip_decoder(pretrained=model_url, image_size=384, vit='base')\n",
+    "model.eval()\n",
+    "model = model.to(device)\n",
+    "\n",
+    "\n",
+    "from modelsn.blip_vqa import blip_vqa\n",
+    "\n",
+    "image_size_vq = 480\n",
+    "transform_vq = transforms.Compose([\n",
+    "    transforms.Resize((image_size_vq,image_size_vq),interpolation=InterpolationMode.BICUBIC),\n",
+    "    transforms.ToTensor(),\n",
+    "    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))\n",
+    "    ]) \n",
+    "\n",
+    "model_url_vq = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'\n",
+    "    \n",
+    "model_vq = blip_vqa(pretrained=model_url_vq, image_size=480, vit='base')\n",
+    "model_vq.eval()\n",
+    "model_vq = model_vq.to(device)\n",
+    "\n",
+    "\n",
+    "\n",
+    "def inference(raw_image, approaches, question):\n",
+    "    \n",
+    "\n",
+    "    image = transform(raw_image).unsqueeze(0).to(device)   \n",
+    "    with torch.no_grad():\n",
+    "        caption = model.generate(image, sample=False, num_beams=3, max_length=20, min_length=5)\n",
+    "\n",
+    "    return (plot_inference(raw_image, caption[0], approaches))\n",
+    "    #return 'caption: '+caption[0]\n",
+    "\n",
+    "    \n",
+    "\n",
+    "    \n",
+    "#PPE Detection code\n",
+    "import numpy as np\n",
+    "import run_code\n",
+    "import gradio as gr\n",
+    "    \n",
+    "\n",
+    "def sepia_call(caption, Input_Image, MDETR_im, Approach):\n",
+    "    pil_image = Input_Image\n",
+    "    open_cv_image = np.asarray(pil_image)\n",
+    "    sepia_img = run_code.run(open_cv_image, Approach)\n",
+    "    images = sepia_img['img']\n",
+    "    texts= sepia_img['text']\n",
+    "\n",
+    "    return (caption, MDETR_im, images, texts)\n",
+    "\n",
+    "\n",
+    "inputs = [gr.inputs.Image(type='pil'),gr.inputs.Radio(choices=[\"Worker Helmet Separately\",\"Worker Helmet Vest\", \"Workers only\"], type=\"value\", default=\"Worker Helmet Vest\", label=\"Model\"),\"textbox\"]\n",
+    "outputs = [gr.outputs.Textbox(label=\"Output\"), \"image\", \"image\", gr.outputs.Textbox(label=\"Output\")]\n",
+    "\n",
+    "\n",
+    "title = \"BLIP + MDETR + PPE Detection\"\n",
+    "\n",
+    "description = \"Gradio demo for BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation by Salesforce Research. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below.\"\n",
+    "\n",
+    "article = \"<p style='text-align: center'><a href='https://arxiv.org/abs/2201.12086' target='_blank'>BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation</a> | <a href='https://github.com/salesforce/BLIP' target='_blank'>Github Repo</a></p>\"\n",
+    "\n",
+    "\n",
+    "gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=[['starry.jpg',\"Image Captioning\",\"None\"]]).launch(share=True,enable_queue=True,cache_examples=False)"
+   ]
+  },
+  {
+   "cell_type": "raw",
+   "id": "b2729aa9",
+   "metadata": {},
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
configs/caption_coco.yaml
ADDED
@@ -0,0 +1,33 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+coco_gt_root: 'annotation/coco_gt'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+batch_size: 32
+init_lr: 1e-5
+
+# vit: 'large'
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 5
+# batch_size: 16
+# init_lr: 2e-6
+
+image_size: 384
+
+# generation configs
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 5
+
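The YAML configs added in this commit are not read by app.py itself (the Space loads checkpoints directly), but the BLIP training and evaluation scripts consume them as plain dictionaries. A hedged sketch of how a file like the one above could be loaded with PyYAML, assuming the repository layout added in this commit:

import yaml

# Read the captioning config above into a plain dict.
with open('configs/caption_coco.yaml', 'r') as f:
    config = yaml.safe_load(f)

print(config['pretrained'])                       # checkpoint URL
print(config['image_size'], config['num_beams'])  # 384 3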
configs/med_config.json
ADDED
@@ -0,0 +1,21 @@
+{
+ "architectures": [
+  "BertModel"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-12,
+ "max_position_embeddings": 512,
+ "model_type": "bert",
+ "num_attention_heads": 12,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "type_vocab_size": 2,
+ "vocab_size": 30524,
+ "encoder_width": 768,
+ "add_cross_attention": true
+}
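configs/med_config.json is a BERT-style configuration for BLIP's multimodal encoder-decoder (note the extra encoder_width and add_cross_attention fields). In the BLIP code this kind of file is typically parsed with Hugging Face's BertConfig; the snippet below illustrates that assumption and is not part of this commit:

from transformers import BertConfig

# Unknown keys such as encoder_width are kept as attributes on the config object.
med_config = BertConfig.from_json_file('configs/med_config.json')
print(med_config.hidden_size, med_config.num_hidden_layers)       # 768 12
print(med_config.encoder_width, med_config.add_cross_attention)   # 768 True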
configs/nlvr.yaml
ADDED
@@ -0,0 +1,21 @@
+image_root: '/export/share/datasets/vision/NLVR2/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_nlvr.pth'
+
+#size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 64
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+max_epoch: 15
+
+image_size: 384
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-5
+min_lr: 0
+
configs/nocaps.yaml
ADDED
@@ -0,0 +1,15 @@
+image_root: '/export/share/datasets/vision/nocaps/'
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'
+
+vit: 'base'
+batch_size: 32
+
+image_size: 384
+
+max_length: 20
+min_length: 5
+num_beams: 3
+prompt: 'a picture of '
configs/pretrain.yaml
ADDED
@@ -0,0 +1,27 @@
+train_file: ['/export/share/junnan-li/VL_pretrain/annotation/coco_karpathy_train.json',
+             '/export/share/junnan-li/VL_pretrain/annotation/vg_caption.json',
+            ]
+laion_path: ''
+
+# size of vit model; base or large
+vit: 'base'
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+
+image_size: 224
+batch_size: 75
+
+queue_size: 57600
+alpha: 0.4
+
+# optimizer
+weight_decay: 0.05
+init_lr: 3e-4
+min_lr: 1e-6
+warmup_lr: 1e-6
+lr_decay_rate: 0.9
+max_epoch: 20
+warmup_steps: 3000
+
+
+
configs/retrieval_coco.yaml
ADDED
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/coco/images/'
+ann_root: 'annotation'
+dataset: 'coco'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_coco.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 12
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 256
+negative_all_rank: True
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
configs/retrieval_flickr.yaml
ADDED
@@ -0,0 +1,34 @@
+image_root: '/export/share/datasets/vision/flickr30k/'
+ann_root: 'annotation'
+dataset: 'flickr'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_retrieval_flickr.pth'
+
+# size of vit model; base or large
+
+vit: 'base'
+batch_size_train: 32
+batch_size_test: 64
+vit_grad_ckpt: True
+vit_ckpt_layer: 4
+init_lr: 1e-5
+
+# vit: 'large'
+# batch_size_train: 16
+# batch_size_test: 32
+# vit_grad_ckpt: True
+# vit_ckpt_layer: 10
+# init_lr: 5e-6
+
+image_size: 384
+queue_size: 57600
+alpha: 0.4
+k_test: 128
+negative_all_rank: False
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 6
+
configs/vqa.yaml
ADDED
@@ -0,0 +1,25 @@
+vqa_root: '/export/share/datasets/vision/VQA/Images/mscoco/' #followed by train2014/
+vg_root: '/export/share/datasets/vision/visual-genome/' #followed by image/
+train_files: ['vqa_train','vqa_val','vg_qa']
+ann_root: 'annotation'
+
+# set pretrained as a file path or an url
+pretrained: 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_vqa.pth'
+
+# size of vit model; base or large
+vit: 'base'
+batch_size_train: 16
+batch_size_test: 32
+vit_grad_ckpt: False
+vit_ckpt_layer: 0
+init_lr: 2e-5
+
+image_size: 480
+
+k_test: 128
+inference: 'rank'
+
+# optimizer
+weight_decay: 0.05
+min_lr: 0
+max_epoch: 10
data/__init__.py
ADDED
@@ -0,0 +1,101 @@
+import torch
+from torch.utils.data import DataLoader
+from torchvision import transforms
+from torchvision.transforms.functional import InterpolationMode
+
+from data.coco_karpathy_dataset import coco_karpathy_train, coco_karpathy_caption_eval, coco_karpathy_retrieval_eval
+from data.nocaps_dataset import nocaps_eval
+from data.flickr30k_dataset import flickr30k_train, flickr30k_retrieval_eval
+from data.vqa_dataset import vqa_dataset
+from data.nlvr_dataset import nlvr_dataset
+from data.pretrain_dataset import pretrain_dataset
+from transform.randaugment import RandomAugment
+
+def create_dataset(dataset, config, min_scale=0.5):
+
+    normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+
+    transform_train = transforms.Compose([
+        transforms.RandomResizedCrop(config['image_size'],scale=(min_scale, 1.0),interpolation=InterpolationMode.BICUBIC),
+        transforms.RandomHorizontalFlip(),
+        RandomAugment(2,5,isPIL=True,augs=['Identity','AutoContrast','Brightness','Sharpness','Equalize',
+                                           'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']),
+        transforms.ToTensor(),
+        normalize,
+    ])
+    transform_test = transforms.Compose([
+        transforms.Resize((config['image_size'],config['image_size']),interpolation=InterpolationMode.BICUBIC),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    if dataset=='pretrain':
+        dataset = pretrain_dataset(config['train_file'], config['laion_path'], transform_train)
+        return dataset
+
+    elif dataset=='caption_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'], prompt=config['prompt'])
+        val_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_caption_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='nocaps':
+        val_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = nocaps_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return val_dataset, test_dataset
+
+    elif dataset=='retrieval_coco':
+        train_dataset = coco_karpathy_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = coco_karpathy_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='retrieval_flickr':
+        train_dataset = flickr30k_train(transform_train, config['image_root'], config['ann_root'])
+        val_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'val')
+        test_dataset = flickr30k_retrieval_eval(transform_test, config['image_root'], config['ann_root'], 'test')
+        return train_dataset, val_dataset, test_dataset
+
+    elif dataset=='vqa':
+        train_dataset = vqa_dataset(transform_train, config['ann_root'], config['vqa_root'], config['vg_root'],
+                                    train_files = config['train_files'], split='train')
+        test_dataset = vqa_dataset(transform_test, config['ann_root'], config['vqa_root'], config['vg_root'], split='test')
+        return train_dataset, test_dataset
+
+    elif dataset=='nlvr':
+        train_dataset = nlvr_dataset(transform_train, config['image_root'], config['ann_root'],'train')
+        val_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'val')
+        test_dataset = nlvr_dataset(transform_test, config['image_root'], config['ann_root'],'test')
+        return train_dataset, val_dataset, test_dataset
+
+
+def create_sampler(datasets, shuffles, num_tasks, global_rank):
+    samplers = []
+    for dataset,shuffle in zip(datasets,shuffles):
+        sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=num_tasks, rank=global_rank, shuffle=shuffle)
+        samplers.append(sampler)
+    return samplers
+
+
+def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
+    loaders = []
+    for dataset,sampler,bs,n_worker,is_train,collate_fn in zip(datasets,samplers,batch_size,num_workers,is_trains,collate_fns):
+        if is_train:
+            shuffle = (sampler is None)
+            drop_last = True
+        else:
+            shuffle = False
+            drop_last = False
+        loader = DataLoader(
+            dataset,
+            batch_size=bs,
+            num_workers=n_worker,
+            pin_memory=True,
+            sampler=sampler,
+            shuffle=shuffle,
+            collate_fn=collate_fn,
+            drop_last=drop_last,
+        )
+        loaders.append(loader)
+    return loaders
+
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from torchvision.datasets.utils import download_url
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
|
9 |
+
from data.utils import pre_caption
|
10 |
+
|
11 |
+
class coco_karpathy_train(Dataset):
|
12 |
+
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_train.json'
        filename = 'coco_karpathy_train.json'

        download_url(url,ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt+pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class coco_karpathy_caption_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        img_id = ann['image'].split('/')[-1].strip('.jpg').split('_')[-1]

        return image, int(img_id)


class coco_karpathy_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. coco/images/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption,max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
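A usage sketch, not part of the committed file: the eval class above plugs into a standard PyTorch DataLoader. The 'coco/images/' and 'annotation/' paths and the transform below are placeholders for illustration, not values taken from this repo's configs.

# Minimal usage sketch for coco_karpathy_caption_eval (placeholder paths and transform).
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
])

val_set = coco_karpathy_caption_eval(transform, 'coco/images/', 'annotation/', split='val')
val_loader = DataLoader(val_set, batch_size=8, shuffle=False)

for images, img_ids in val_loader:
    # images: a (8, 3, 384, 384) tensor; img_ids: COCO image ids parsed from the filenames
    break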
data/flickr30k_dataset.py
ADDED
@@ -0,0 +1,93 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption

class flickr30k_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        filename = 'flickr30k_train.json'

        download_url(url,ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filename),'r'))
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        self.img_ids = {}
        n = 0
        for ann in self.annotation:
            img_id = ann['image_id']
            if img_id not in self.img_ids.keys():
                self.img_ids[img_id] = n
                n += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        caption = self.prompt+pre_caption(ann['caption'], self.max_words)

        return image, caption, self.img_ids[ann['image_id']]


class flickr30k_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        image_root (string): Root directory of images (e.g. flickr30k/)
        ann_root (string): directory to store the annotation file
        split (string): val or test
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for i, caption in enumerate(ann['caption']):
                self.text.append(pre_caption(caption,max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, index
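A small sketch, not part of the committed file, of how the retrieval-eval bookkeeping above can be consumed; the paths are placeholders.

# Sketch only: inspect the caption/image index maps built by flickr30k_retrieval_eval.
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
test_set = flickr30k_retrieval_eval(transform, 'flickr30k/', 'annotation/', split='test')

print(len(test_set.image), 'images,', len(test_set.text), 'pre-processed captions')
# img2txt maps an image index to the indices of its captions, and txt2img maps
# each caption index back to its image index, so the two stay consistent.
first_caption_ids = test_set.img2txt[0]
assert all(test_set.txt2img[t] == 0 for t in first_caption_ids)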
data/nlvr_dataset.py
ADDED
@@ -0,0 +1,78 @@
import os
import json
import random

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption

class nlvr_dataset(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        image_root (string): Root directory of images
        ann_root (string): directory to store the annotation file
        split (string): train, val or test
        '''
        urls = {'train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_train.json',
                'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_dev.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nlvr_test.json'}
        filenames = {'train':'nlvr_train.json','val':'nlvr_dev.json','test':'nlvr_test.json'}

        download_url(urls[split],ann_root)
        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))

        self.transform = transform
        self.image_root = image_root


    def __len__(self):
        return len(self.annotation)


    def __getitem__(self, index):

        ann = self.annotation[index]

        image0_path = os.path.join(self.image_root,ann['images'][0])
        image0 = Image.open(image0_path).convert('RGB')
        image0 = self.transform(image0)

        image1_path = os.path.join(self.image_root,ann['images'][1])
        image1 = Image.open(image1_path).convert('RGB')
        image1 = self.transform(image1)

        sentence = pre_caption(ann['sentence'], 40)

        if ann['label']=='True':
            label = 1
        else:
            label = 0

        words = sentence.split(' ')

        if 'left' not in words and 'right' not in words:
            if random.random()<0.5:
                return image0, image1, sentence, label
            else:
                return image1, image0, sentence, label
        else:
            if random.random()<0.5:
                return image0, image1, sentence, label
            else:
                new_words = []
                for word in words:
                    if word=='left':
                        new_words.append('right')
                    elif word=='right':
                        new_words.append('left')
                    else:
                        new_words.append(word)

                sentence = ' '.join(new_words)
                return image1, image0, sentence, label
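To make the image-order augmentation above easier to follow, here is a small self-contained sketch, not part of the committed file, of the left/right word swap that accompanies returning the two images in reversed order.

# Sketch of the word swap applied in nlvr_dataset.__getitem__ when the images
# are swapped and the sentence mentions 'left' or 'right'.
def swap_left_right(sentence):
    new_words = []
    for word in sentence.split(' '):
        if word == 'left':
            new_words.append('right')
        elif word == 'right':
            new_words.append('left')
        else:
            new_words.append(word)
    return ' '.join(new_words)

print(swap_left_right('the left image contains twice the number of dogs as the right image'))
# -> 'the right image contains twice the number of dogs as the left image'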
data/nocaps_dataset.py
ADDED
@@ -0,0 +1,32 @@
import os
import json

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from PIL import Image

class nocaps_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/nocaps_test.json'}
        filenames = {'val':'nocaps_val.json','test':'nocaps_test.json'}

        download_url(urls[split],ann_root)

        self.annotation = json.load(open(os.path.join(ann_root,filenames[split]),'r'))
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image_path = os.path.join(self.image_root,ann['image'])
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        return image, int(ann['img_id'])
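A brief usage sketch, not part of the committed file; in this repo the loader is normally built through data.create_dataset and create_loader, so the direct construction below is only illustrative and the paths are placeholders.

# Illustrative only: nocaps_eval yields (transformed image, integer img_id),
# which is the pair the evaluate() loop in eval_nocaps.py consumes.
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((384, 384)), transforms.ToTensor()])
val_set = nocaps_eval(transform, 'nocaps/images/', 'annotation/', split='val')  # placeholder paths
val_loader = DataLoader(val_set, batch_size=16, shuffle=False)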
data/pretrain_dataset.py
ADDED
@@ -0,0 +1,59 @@
import json
import os
import random

from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from data.utils import pre_caption
import os,glob

class pretrain_dataset(Dataset):
    def __init__(self, ann_file, laion_path, transform):

        self.ann_pretrain = []
        for f in ann_file:
            print('loading '+f)
            ann = json.load(open(f,'r'))
            self.ann_pretrain += ann

        self.laion_path = laion_path
        if self.laion_path:
            self.laion_files = glob.glob(os.path.join(laion_path,'*.json'))

            print('loading '+self.laion_files[0])
            with open(self.laion_files[0],'r') as f:
                self.ann_laion = json.load(f)

            self.annotation = self.ann_pretrain + self.ann_laion
        else:
            self.annotation = self.ann_pretrain

        self.transform = transform


    def reload_laion(self, epoch):
        n = epoch%len(self.laion_files)
        print('loading '+self.laion_files[n])
        with open(self.laion_files[n],'r') as f:
            self.ann_laion = json.load(f)

        self.annotation = self.ann_pretrain + self.ann_laion


    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        image = Image.open(ann['image']).convert('RGB')
        image = self.transform(image)
        caption = pre_caption(ann['caption'],30)

        return image, caption
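A short sketch, not part of the committed file, of how reload_laion rotates LAION shards across epochs. The annotation file and shard folder names below are hypothetical, and the transform is a minimal example.

# Hypothetical inputs: one JSON of (image, caption) pairs plus a folder of LAION
# shard JSONs. reload_laion(epoch) swaps in shard number epoch % num_shards.
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
dataset = pretrain_dataset(['annotation/cc3m.json'], 'laion/', transform)

for epoch in range(3):
    if epoch > 0:
        dataset.reload_laion(epoch)   # re-mixes ann_pretrain with the next LAION shard
    print('epoch', epoch, 'samples:', len(dataset))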
data/utils.py
ADDED
@@ -0,0 +1,112 @@
import re
import json
import os

import torch
import torch.distributed as dist

import utils

def pre_caption(caption,max_words=50):
    caption = re.sub(
        r"([.!\"()*#:;~])",
        ' ',
        caption.lower(),
    )
    caption = re.sub(
        r"\s{2,}",
        ' ',
        caption,
    )
    caption = caption.rstrip('\n')
    caption = caption.strip(' ')

    #truncate caption
    caption_words = caption.split(' ')
    if len(caption_words)>max_words:
        caption = ' '.join(caption_words[:max_words])

    return caption

def pre_question(question,max_ques_words=50):
    question = re.sub(
        r"([.!\"()*#:;~])",
        '',
        question.lower(),
    )
    question = question.rstrip(' ')

    #truncate question
    question_words = question.split(' ')
    if len(question_words)>max_ques_words:
        question = ' '.join(question_words[:max_ques_words])

    return question


def save_result(result, result_dir, filename, remove_duplicate=''):
    result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,utils.get_rank()))
    final_result_file = os.path.join(result_dir, '%s.json'%filename)

    json.dump(result,open(result_file,'w'))

    dist.barrier()

    if utils.is_main_process():
        # combine results from all processes
        result = []

        for rank in range(utils.get_world_size()):
            result_file = os.path.join(result_dir, '%s_rank%d.json'%(filename,rank))
            res = json.load(open(result_file,'r'))
            result += res

        if remove_duplicate:
            result_new = []
            id_list = []
            for res in result:
                if res[remove_duplicate] not in id_list:
                    id_list.append(res[remove_duplicate])
                    result_new.append(res)
            result = result_new

        json.dump(result,open(final_result_file,'w'))
        print('result file saved to %s'%final_result_file)

    return final_result_file



from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap
from torchvision.datasets.utils import download_url

def coco_caption_eval(coco_gt_root, results_file, split):
    urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val_gt.json',
            'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test_gt.json'}
    filenames = {'val':'coco_karpathy_val_gt.json','test':'coco_karpathy_test_gt.json'}

    download_url(urls[split],coco_gt_root)
    annotation_file = os.path.join(coco_gt_root,filenames[split])

    # create coco object and coco_result object
    coco = COCO(annotation_file)
    coco_result = coco.loadRes(results_file)

    # create coco_eval object by taking coco and coco_result
    coco_eval = COCOEvalCap(coco, coco_result)

    # evaluate on a subset of images by setting
    # coco_eval.params['image_id'] = coco_result.getImgIds()
    # please remove this line when evaluating the full validation set
    # coco_eval.params['image_id'] = coco_result.getImgIds()

    # evaluate results
    # SPICE will take a few minutes the first time, but speeds up due to caching
    coco_eval.evaluate()

    # print output evaluation scores
    for metric, score in coco_eval.eval.items():
        print(f'{metric}: {score:.3f}')

    return coco_eval
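For illustration, not part of the committed file, this is the kind of normalization pre_caption and pre_question perform; the sample strings are made up.

# pre_caption lower-cases, replaces the characters .!"()*#:;~ with spaces,
# collapses runs of whitespace, and truncates to max_words tokens.
print(pre_caption('A man (on the LEFT) is riding a horse!! Near the beach.', max_words=30))
# -> 'a man on the left is riding a horse near the beach'

# pre_question removes the same characters instead of replacing them with spaces.
print(pre_question('Is the UMBRELLA open?', max_ques_words=50))
# -> 'is the umbrella open?'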
data/vqa_dataset.py
ADDED
@@ -0,0 +1,88 @@
import os
import json
import random
from PIL import Image

import torch
from torch.utils.data import Dataset
from data.utils import pre_question

from torchvision.datasets.utils import download_url

class vqa_dataset(Dataset):
    def __init__(self, transform, ann_root, vqa_root, vg_root, train_files=[], split="train"):
        self.split = split

        self.transform = transform
        self.vqa_root = vqa_root
        self.vg_root = vg_root

        if split=='train':
            urls = {'vqa_train':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_train.json',
                    'vqa_val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_val.json',
                    'vg_qa':'https://storage.googleapis.com/sfr-vision-language-research/datasets/vg_qa.json'}

            self.annotation = []
            for f in train_files:
                download_url(urls[f],ann_root)
                self.annotation += json.load(open(os.path.join(ann_root,'%s.json'%f),'r'))
        else:
            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/vqa_test.json',ann_root)
            self.annotation = json.load(open(os.path.join(ann_root,'vqa_test.json'),'r'))

            download_url('https://storage.googleapis.com/sfr-vision-language-research/datasets/answer_list.json',ann_root)
            self.answer_list = json.load(open(os.path.join(ann_root,'answer_list.json'),'r'))


    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        if ann['dataset']=='vqa':
            image_path = os.path.join(self.vqa_root,ann['image'])
        elif ann['dataset']=='vg':
            image_path = os.path.join(self.vg_root,ann['image'])

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        if self.split == 'test':
            question = pre_question(ann['question'])
            question_id = ann['question_id']
            return image, question, question_id


        elif self.split=='train':

            question = pre_question(ann['question'])

            if ann['dataset']=='vqa':
                answer_weight = {}
                for answer in ann['answer']:
                    if answer in answer_weight.keys():
                        answer_weight[answer] += 1/len(ann['answer'])
                    else:
                        answer_weight[answer] = 1/len(ann['answer'])

                answers = list(answer_weight.keys())
                weights = list(answer_weight.values())

            elif ann['dataset']=='vg':
                answers = [ann['answer']]
                weights = [0.2]

            return image, question, answers, weights


def vqa_collate_fn(batch):
    image_list, question_list, answer_list, weight_list, n = [], [], [], [], []
    for image, question, answer, weights in batch:
        image_list.append(image)
        question_list.append(question)
        weight_list += weights
        answer_list += answer
        n.append(len(answer))
    return torch.stack(image_list,dim=0), question_list, answer_list, torch.Tensor(weight_list), n
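A usage sketch, not part of the committed file, of the weighted-answer collation above; the roots, the batch size, and the transform are placeholders rather than values from this repo's configs.

# Sketch: vqa_collate_fn flattens the per-sample answer lists and keeps n, the
# number of candidate answers per question, so downstream code can re-group them.
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([transforms.Resize((480, 480)), transforms.ToTensor()])
train_set = vqa_dataset(transform, 'annotation/', 'vqa/images/', 'vg/images/',
                        train_files=['vqa_train'], split='train')   # placeholder roots
train_loader = DataLoader(train_set, batch_size=4, collate_fn=vqa_collate_fn)

images, questions, answers, weights, n = next(iter(train_loader))
# len(questions) == 4, and len(answers) == len(weights) == sum(n)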
elephant.jpg
ADDED
eval_nocaps.py
ADDED
@@ -0,0 +1,118 @@
'''
 * Copyright (c) 2022, salesforce.com, inc.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
 * By Junnan Li
'''
import argparse
import os
import ruamel_yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader

from models.blip import blip_decoder
import utils
from data import create_dataset, create_sampler, create_loader
from data.utils import save_result

@torch.no_grad()
def evaluate(model, data_loader, device, config):
    # evaluate
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'
    print_freq = 10

    result = []
    for image, image_id in metric_logger.log_every(data_loader, print_freq, header):

        image = image.to(device)

        captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
                                  min_length=config['min_length'], repetition_penalty=1.1)

        for caption, img_id in zip(captions, image_id):
            result.append({"image_id": img_id.item(), "caption": caption})

    return result


def main(args, config):
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    val_dataset, test_dataset = create_dataset('nocaps', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
    else:
        samplers = [None,None]

    val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
                                            batch_size=[config['batch_size']]*2,num_workers=[4,4],
                                            is_trains=[False, False], collate_fns=[None,None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         prompt=config['prompt'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    val_result = evaluate(model_without_ddp, val_loader, device, config)
    val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
    test_result = evaluate(model_without_ddp, test_loader, device, config)
    test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nocaps.yaml')
    parser.add_argument('--output_dir', default='output/NoCaps')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))

    main(args, config)
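The script reads its settings from configs/nocaps.yaml; the actual values live in that file, but from the code above the keys it expects look like the sketch below, which is not part of the committed file and whose values are illustrative placeholders.

# Keys consumed by eval_nocaps.py (values are illustrative, not taken from the config).
config_sketch = {
    'pretrained': '<path or URL to a BLIP captioning checkpoint>',
    'vit': 'base',          # passed to blip_decoder
    'image_size': 384,      # passed to blip_decoder
    'prompt': 'a picture of ',
    'batch_size': 32,       # used for both val and test loaders
    'num_beams': 3,         # generation settings used inside evaluate()
    'max_length': 20,
    'min_length': 5,
}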
examples/ex1.jpg
ADDED
examples/ex2.jpg
ADDED
examples/ex3.jpg
ADDED
extras/.DS_Store
ADDED
Binary file (6.15 kB). View file
extras/sample-images/0.JPG
ADDED
extras/sample-images/1.JPG
ADDED
extras/sample-images/10.jpg
ADDED
extras/sample-images/2.jpg
ADDED
extras/sample-images/3.jpg
ADDED
extras/sample-images/4.jpg
ADDED
extras/sample-images/5.jpg
ADDED
extras/sample-images/6.JPG
ADDED
extras/sample-images/7.JPG
ADDED
extras/sample-images/8.jpg
ADDED
extras/sample-images/9.jpg
ADDED
foo.png
ADDED
gradio_cached_examples/log.csv
ADDED
@@ -0,0 +1,2 @@
Output
caption: a painting of a starry night over a city
local_run.ipynb
ADDED
@@ -0,0 +1,347 @@
Jupyter notebook (nbformat 4, Python 3.8.12 kernel) with one executed code cell that launches the Gradio PPE-detection demo locally, followed by three empty code cells. Source of the executed cell:

import numpy as np
import run_code
import cv2
import gradio as gr


def sepia(Input_Image, Approach):
    pil_image = Input_Image
    open_cv_image = np.asarray(pil_image)
    # Convert RGB to BGR
    #open_cv_image = open_cv_image[:, :, ::-1].copy()
    #Approach = 3
    sepia_img = run_code.run(open_cv_image, Approach)
    images = sepia_img['img']
    texts= sepia_img['text']
    #print (labels)
    return images, texts

image = [gr.inputs.Image(type="pil"), gr.inputs.Radio([1, 2, 3])]
#output = ["image", gr.outputs.Label(num_top_classes=4)]
output = ["image", gr.outputs.Textbox(type="auto")]
#output = gr.outputs.Label(num_top_classes=4)

title="Real-time Detection of Personal-Protective-Equipment (PPE)"
description="This demo is the implementation of Real-time Detection of Personal-Protective-Equipment (PPE) paper https://github.com/ciber-lab/pictor-ppe" \
            " - by Sanjay Kamath "
examples = [["examples/ex1.jpg", 1], ["examples/ex2.jpg", 2], ["examples/ex3.jpg", 3]]

#iface = gr.Interface(sepia , [ gr.inputs.Image(shape=(200, 200)), gr.inputs.Radio([1, 2, 3])], "image", title=title,
#                     examples = [["examples/ex1.jpg"], ["examples/ex2.jpg"], ["examples/ex3.jpg"]],
#                     description=description)

iface = gr.Interface(fn=sepia, inputs=image, outputs=output, title=title, description=description, examples=examples)

iface.launch()

Saved outputs of the cell: the Gradio launch messages ("Running on local URL: http://127.0.0.1:7860/" and "To create a public link, set `share=True` in `launch()`."), an inline IFrame pointing at that URL, the execute result (<fastapi.applications.FastAPI ...>, 'http://127.0.0.1:7860/', None), a TensorFlow oneDNN CPU-instructions notice plus two tf.function retracing warnings, and repeated per-image detection logs of the form "Total workers: N", "Number of Helmets: N", "Number of Vests: N", "Workers wearing helmet and vest / only vest / only helmet: N", and "dict vals: {'W': ..., 'WH': ..., 'WHV': ..., 'WV': ...}" produced while running the example images.
model-data/.DS_Store
ADDED
Binary file (6.15 kB). View file
model-data/weights/pictor-ppe-v302-a1-yolo-v3-weights.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3ec800aa5acdd9719ff5e63b34d1374e5c8a31e17f38f3a8250bf1aeeac1a972
size 246910096
model-data/weights/pictor-ppe-v302-a2-yolo-v3-weights.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:317831ba378b8ec02e24e57859876eb0348284c8a75155143c9df85ee478c47b
size 246931600
model-data/weights/pictor-ppe-v302-a3-yolo-v3-weights.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d06d4956d0f6b3ac71f02e103e9efdc4b222ce83aeae232f65ee6c04ee1dd2d7
size 246867088
model-data/weights/readme.md
ADDED
@@ -0,0 +1 @@
Download the trained weights of the YOLO models ([Google Drive folder](https://drive.google.com/drive/folders/13tCdROHnS0c5VibW1VO8pOEj0rXEvvGj?usp=sharing)) and put them in this folder.
modelsn/__init__.py
ADDED
File without changes