Update with md5sum and half precision inference

- app.py +1 -11
- app_vqa.py +1 -1
- label_prettify.py +86 -74
- prismer/configs/experts.yaml +3 -2
- prismer/dataset/caption_dataset.py +3 -5
- prismer/dataset/utils.py +13 -9
- prismer/experts/depth/generate_dataset.py +4 -6
- prismer/experts/edge/generate_dataset.py +4 -6
- prismer/experts/generate_depth.py +1 -2
- prismer/experts/generate_edge.py +1 -2
- prismer/experts/generate_normal.py +1 -2
- prismer/experts/generate_objdet.py +1 -2
- prismer/experts/generate_ocrdet.py +1 -2
- prismer/experts/generate_segmentation.py +1 -2
- prismer/experts/model_bank.py +2 -0
- prismer/experts/normal/generate_dataset.py +4 -6
- prismer/experts/obj_detection/generate_dataset.py +6 -7
- prismer/experts/ocr_detection/generate_dataset.py +4 -6
- prismer/experts/segmentation/generate_dataset.py +3 -5
- prismer/helpers/images/COCO_test2015_000000000014.jpg +0 -0
- prismer/helpers/images/COCO_test2015_000000000016.jpg +0 -0
- prismer/helpers/images/COCO_test2015_000000000019.jpg +0 -0
- prismer/helpers/images/COCO_test2015_000000000128.jpg +0 -0
- prismer/helpers/images/COCO_test2015_000000000155.jpg +0 -0
- prismer/helpers/intro.png +0 -0
- prismer/model/prismer.py +6 -2
- prismer_model.py +64 -29
- requirements.txt +1 -1
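
Taken together, this commit does two things: expert labels are now cached under a name derived from the md5 hash of the uploaded image, so repeating a query on the same image skips the expert models entirely, and the Prismer model plus all expert tensors are cast to half precision and moved to cuda:0 for faster inference. Below is a rough sketch of the md5-keyed caching idea; compute_md5 and the label layout come from prismer_model.py in this diff, while cached_label_paths is only an illustrative helper, not part of the commit.

import hashlib
import pathlib


def compute_md5(image_path: str) -> str:
    # Hash the raw bytes so the same upload always maps to the same label name.
    with open(image_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()


def cached_label_paths(image_path: str) -> dict:
    # Expert outputs are stored at prismer/helpers/labels/<expert>/helpers/images/<md5>.png,
    # so the experts only need to re-run when these files are missing.
    im_name = compute_md5(image_path)
    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
    root = pathlib.Path('prismer/helpers/labels')
    return {key: root / key / 'helpers/images' / f'{im_name}.png' for key in keys}
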
app.py
CHANGED
@@ -3,18 +3,8 @@
 from __future__ import annotations
 
 import os
-import shutil
-import subprocess
-
 import gradio as gr
 
-if os.getenv('SYSTEM') == 'spaces':
-    with open('patch') as f:
-        subprocess.run('patch -p1'.split(), cwd='prismer', stdin=f)
-    shutil.copytree('prismer/helpers/images',
-                    'prismer/images',
-                    dirs_exist_ok=True)
-
 from app_caption import create_demo as create_demo_caption
 from app_vqa import create_demo as create_demo_vqa
 from prismer_model import build_deformable_conv, download_models
@@ -36,7 +26,7 @@ if (SPACE_ID := os.getenv('SPACE_ID')) is not None:
     description += f'For faster inference without waiting in queue, you may duplicate the space and upgrade to GPU in settings. <a href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img style="display: inline; margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space" /></a>'
 
 
-with gr.Blocks() as demo:
+with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
     gr.Markdown(description)
     with gr.Tabs():
         with gr.TabItem('Zero-shot Image Captioning'):

app_vqa.py
CHANGED
@@ -35,7 +35,7 @@ def create_demo() -> gr.Blocks:
     paths = sorted(pathlib.Path('prismer/images').glob('*'))
     ex_questions = ['What is the man on the left doing?',
                     'What is this person doing?',
-                    'How many cows in this image?',
+                    'How many cows are in this image?',
                     'What is the type of animal in this image?',
                     'What toy is it?']
     examples = [[path.as_posix(), 'Prismer-Base', ex_questions[i]] for i, path in enumerate(paths)]

label_prettify.py
CHANGED
@@ -5,6 +5,7 @@ import torch
 import matplotlib.pyplot as plt
 import matplotlib
 import numpy as np
+import shutil
 
 from prismer.utils import create_ade20k_label_colormap
 
@@ -23,101 +24,109 @@ def islight(rgb):
 
 
 def depth_prettify(file_path):
+    pretty_path = file_path.replace('.png', '_p.png')
+    if not os.path.exists(pretty_path):
+        depth = plt.imread(file_path)
+        plt.imsave(pretty_path, depth, cmap='rainbow')
 
 
 def obj_detection_prettify(rgb_path, path_name):
+    pretty_path = path_name.replace('.png', '_p.png')
+    if not os.path.exists(pretty_path):
+        rgb = plt.imread(rgb_path)
+        obj_labels = plt.imread(path_name)
+        obj_labels_dict = json.load(open(path_name.replace('.png', '.json')))
+
+        plt.imshow(rgb)
+
+        if len(np.unique(obj_labels)) == 1:
+            plt.axis('off')
+            plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
+            plt.close()
+        else:
+            num_objs = np.unique(obj_labels)[:-1].max()
+            plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
+            cmap = matplotlib.colormaps.get_cmap('terrain')
+            for i in np.unique(obj_labels)[:-1]:
+                obj_idx_all = np.where(obj_labels == i)
+                x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
+                obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
+                obj_name = obj_name.split(',')[0]
+                if islight([c*255 for c in cmap(i / num_objs)[:3]]):
+                    plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
+                else:
+                    plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
+
+            plt.axis('off')
+            plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
+            plt.close()
 
 
 def seg_prettify(rgb_path, file_name):
+    pretty_path = file_name.replace('.png', '_p.png')
+    if not os.path.exists(pretty_path):
+        rgb = plt.imread(rgb_path)
+        seg_labels = plt.imread(file_name)
+
+        plt.imshow(rgb)
+
+        seg_map = np.zeros(list(seg_labels.shape) + [3], dtype=np.int16)
+        for i in np.unique(seg_labels):
+            seg_map[seg_labels == i] = ade_color[int(i * 255)]
+
+        plt.imshow(seg_map, alpha=0.8)
+
+        for i in np.unique(seg_labels):
+            obj_idx_all = np.where(seg_labels == i)
+            if len(obj_idx_all[0]) > 20:  # only plot the label with its number of labelled pixel more than 20
+                obj_idx = random.randint(0, len(obj_idx_all[0]) - 1)
+                x, y = obj_idx_all[1][obj_idx], obj_idx_all[0][obj_idx]
+                obj_name = coco_label_map[int(i * 255)]
+                obj_name = obj_name.split(',')[0]
+                if islight(seg_map[int(y), int(x)]):
+                    plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
+                else:
+                    plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
+
+        plt.axis('off')
+        plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
+        plt.close()
 
 
 def ocr_detection_prettify(rgb_path, file_name):
+    pretty_path = file_name.replace('.png', '_p.png')
+    if not os.path.exists(pretty_path):
+        if os.path.exists(file_name):
+            rgb = plt.imread(rgb_path)
+            ocr_labels = plt.imread(file_name)
+            ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
+
+            plt.imshow(rgb)
+            plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
+
+            for i in np.unique(ocr_labels)[:-1]:
+                text_idx_all = np.where(ocr_labels == i)
+                x, y = text_idx_all[1].mean(), text_idx_all[0].mean()
+                text = ocr_labels_dict[int(i * 255)]['text']
+                plt.text(x, y, text, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
+
+            plt.axis('off')
+            plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
+            plt.close()
+        else:
+            rgb = plt.imread(rgb_path)
+            ocr_labels = np.ones_like(rgb, dtype=np.float32())
+
+            plt.imshow(rgb)
+            plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
+
+            x, y = rgb.shape[1] / 2, rgb.shape[0] / 2
+            plt.text(x, y, 'No text detected', c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
+            plt.axis('off')

+            os.makedirs(os.path.dirname(file_name), exist_ok=True)
+            plt.savefig(pretty_path, bbox_inches='tight', transparent=True, pad_inches=0)
+            plt.close()
 
 
 def label_prettify(rgb_path, expert_paths):
@@ -130,4 +139,7 @@ def label_prettify(rgb_path, expert_paths):
             ocr_detection_prettify(rgb_path, expert_path)
         elif 'obj' in expert_path:
             obj_detection_prettify(rgb_path, expert_path)
+        else:
+            pretty_path = expert_path.replace('.png', '_p.png')
+            if not os.path.exists(pretty_path):
+                shutil.copyfile(expert_path, pretty_path)

prismer/configs/experts.yaml
CHANGED
@@ -1,2 +1,3 @@
-data_path:
+data_path: helpers
+im_name: 87dfaeb4978ce05aa7be5e5b4cc1273a
+save_path: helpers/labels

prismer/dataset/caption_dataset.py
CHANGED
@@ -32,10 +32,7 @@ class Caption(Dataset):
         elif self.dataset == 'nocaps':
             self.data_list = json.load(open(os.path.join(self.data_path, 'nocaps_val.json'), 'r'))
         elif self.dataset == 'demo':
-            self.data_list = [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpg')]
-            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.png')]
-            self.data_list += [{'image': data} for f in data_folders for data in glob.glob(f + '*.jpeg')]
+            self.data_list = [{'image': f'helpers/images/{config["im_name"]}.jpg'}]
 
     def __len__(self):
         return len(self.data_list)
@@ -50,10 +47,11 @@ class Caption(Dataset):
         elif self.dataset == 'demo':
             img_path_split = self.data_list[index]['image'].split('/')
             img_name = img_path_split[-2] + '/' + img_path_split[-1]
-            image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts)
+            image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts)
 
         experts = self.transform(image, labels)
         experts = post_label_process(experts, labels_info)
+        experts['rgb'] = experts['rgb'].half()
 
         if self.train:
             caption = pre_caption(self.prefix + ' ' + self.data_list[index]['caption'], max_words=30)

prismer/dataset/utils.py
CHANGED
@@ -5,6 +5,7 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import os
+import pathlib
 import re
 import json
 import torch
@@ -14,10 +15,12 @@ import torchvision.transforms as transforms
 import torchvision.transforms.functional as transforms_f
 from dataset.randaugment import RandAugment
 
+cur_dir = pathlib.Path(__file__).parent
+
+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features']
+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features']
+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features']
+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt')
 
 
 class Transform:
@@ -119,7 +122,8 @@ def post_label_process(inputs, labels_info):
     for exp in inputs:
         if exp in ['depth', 'normal', 'edge']:  # remap to -1 to 1 range
             inputs[exp] = 2 * (inputs[exp] - inputs[exp].min()) / (inputs[exp].max() - inputs[exp].min() + eps) - 1
+            inputs[exp] = inputs[exp].half()
+
         elif exp == 'seg_coco':  # in-paint with CLIP features
             text_emb = torch.empty([64, *inputs[exp].shape[1:]])
             for l in inputs[exp].unique():
@@ -127,7 +131,7 @@ def post_label_process(inputs, labels_info):
                     text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                 else:
                     text_emb[:, (inputs[exp][0] == l)] = COCO_FEATURES[l].unsqueeze(-1)
-            inputs[exp] = text_emb
+            inputs[exp] = text_emb.half()
 
         elif exp == 'seg_ade':  # in-paint with CLIP features
             text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -136,7 +140,7 @@ def post_label_process(inputs, labels_info):
                     text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                 else:
                     text_emb[:, (inputs[exp][0] == l)] = ADE_FEATURES[l].unsqueeze(-1)
-            inputs[exp] = text_emb
+            inputs[exp] = text_emb.half()
 
         elif exp == 'obj_detection':  # in-paint with CLIP features
             text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -146,7 +150,7 @@ def post_label_process(inputs, labels_info):
                     text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                 else:
                     text_emb[:, (inputs[exp][0] == l)] = DETECTION_FEATURES[label_map[str(l.item())]].unsqueeze(-1)
-            inputs[exp] = {'label': text_emb, 'instance': inputs[exp]}
+            inputs[exp] = {'label': text_emb.half(), 'instance': inputs[exp].half()}
 
         elif exp == 'ocr_detection':  # in-paint with CLIP features
             text_emb = torch.empty([64, *inputs[exp].shape[1:]])
@@ -156,7 +160,7 @@ def post_label_process(inputs, labels_info):
                     text_emb[:, (inputs[exp][0] == l)] = BACKGROUND_FEATURES.unsqueeze(-1)
                 else:
                     text_emb[:, (inputs[exp][0] == l)] = label_map[l.item()]['features'].unsqueeze(-1)
-            inputs[exp] = text_emb
+            inputs[exp] = text_emb.half()
    return inputs

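The dataset/utils.py change also swaps the cwd-relative feature paths for paths anchored at the module's own directory, so the CLIP feature files load no matter where the demo process is launched from. A small generic sketch of that pattern (not repo code; load_features is a made-up helper):

import pathlib

import torch

cur_dir = pathlib.Path(__file__).parent  # directory that contains this module


def load_features(name: str):
    # Resolves next to the module file, so the lookup works from any working directory.
    return torch.load(cur_dir / f'{name}_features.pt')
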
prismer/experts/depth/generate_dataset.py
CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class Dataset(Dataset):
-    def __init__(self, data_path, transform):
-        self.data_path = data_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)
@@ -29,4 +27,4 @@ class Dataset(Dataset):
         image = Image.open(image_path).convert('RGB')
         img_size = [image.size[0], image.size[1]]
         image = self.transform(image)
-        return image, image_path, img_size
+        return image.half(), image_path, img_size

prismer/experts/edge/generate_dataset.py
CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class Dataset(Dataset):
-    def __init__(self, data_path, transform):
-        self.data_path = data_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)
@@ -29,4 +27,4 @@ class Dataset(Dataset):
         image = Image.open(image_path).convert('RGB')
         img_size = [image.size[0], image.size[1]]
         image = self.transform(image)
-        return torch.flip(image, dims=(0, )) * 255., image_path, img_size
+        return torch.flip(image.half(), dims=(0, )) * 255., image_path, img_size

prismer/experts/generate_depth.py
CHANGED
@@ -21,11 +21,10 @@ model, transform = load_expert_model(task='depth')
 accelerator = Accelerator(mixed_precision='fp16')
 
 config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
-data_path = config['data_path']
 save_path = os.path.join(config['save_path'], 'depth')
 
 batch_size = 64
-dataset = Dataset(data_path, transform)
+dataset = Dataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/generate_edge.py
CHANGED
@@ -23,11 +23,10 @@ model, transform = load_expert_model(task='edge')
 accelerator = Accelerator(mixed_precision='fp16')
 
 config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
-data_path = config['data_path']
 save_path = os.path.join(config['save_path'], 'edge')
 
 batch_size = 64
-dataset = Dataset(data_path, transform)
+dataset = Dataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/generate_normal.py
CHANGED
@@ -23,11 +23,10 @@ model, transform = load_expert_model(task='normal')
 accelerator = Accelerator(mixed_precision='fp16')
 
 config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
-data_path = config['data_path']
 save_path = os.path.join(config['save_path'], 'normal')
 
 batch_size = 64
-dataset = CustomDataset(data_path, transform)
+dataset = CustomDataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/generate_objdet.py
CHANGED
@@ -26,9 +26,8 @@ config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
 data_path = config['data_path']
 save_path = config['save_path']
 
-depth_path = os.path.join(save_path, 'depth', data_path.split('/')[-1])
 batch_size = 32
-dataset = Dataset(data_path, depth_path, transform)
+dataset = Dataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/generate_ocrdet.py
CHANGED
@@ -27,11 +27,10 @@ accelerator = Accelerator(mixed_precision='fp16')
 pca_clip = pk.load(open('dataset/clip_pca.pkl', 'rb'))
 
 config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
-data_path = config['data_path']
 save_path = os.path.join(config['save_path'], 'ocr_detection')
 
 batch_size = 32
-dataset = Dataset(data_path, transform)
+dataset = Dataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/generate_segmentation.py
CHANGED
@@ -21,11 +21,10 @@ model, transform = load_expert_model(task='seg_coco')
 accelerator = Accelerator(mixed_precision='fp16')
 
 config = yaml.load(open('configs/experts.yaml', 'r'), Loader=yaml.Loader)
-data_path = config['data_path']
 save_path = os.path.join(config['save_path'], 'seg_coco')
 
 batch_size = 4
-dataset = Dataset(data_path, transform)
+dataset = Dataset(config, transform)
 data_loader = torch.utils.data.DataLoader(
     dataset=dataset,
     batch_size=batch_size,

prismer/experts/model_bank.py
CHANGED
@@ -131,6 +131,8 @@ def load_expert_model(task=None):
         model = None
         transform = None
 
+    if 'seg' not in task:
+        model = model.half()
     model.eval()
     return model, transform
 
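The guard above keeps the segmentation expert in float32 while converting every other expert to fp16; the matching .half() casts in the generate_dataset.py loaders and in dataset/utils.py exist because fp16 weights require fp16 inputs. A generic PyTorch illustration of that pairing (not repo code, and it assumes a CUDA device):

import torch

model = torch.nn.Linear(4, 2).cuda().half()  # weights in fp16, as model_bank.py does for non-seg experts
x = torch.randn(1, 4).cuda().half()          # inputs cast with .half(), as the dataset loaders now do
with torch.inference_mode():
    y = model(x)                             # a float32 x here would raise a dtype mismatch
print(y.dtype)  # torch.float16
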
prismer/experts/normal/generate_dataset.py
CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class CustomDataset(Dataset):
-    def __init__(self, data_path, transform):
-        self.data_path = data_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)
@@ -29,6 +27,6 @@ class CustomDataset(Dataset):
         image = Image.open(image_path).convert('RGB')
         img_size = [image.size[0], image.size[1]]
         image = self.transform(image)
-        return image, image_path, img_size
+        return image.half(), image_path, img_size
 
 
prismer/experts/obj_detection/generate_dataset.py
CHANGED
@@ -5,6 +5,7 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import glob
+import os
 import torch
 
 from torch.utils.data import Dataset
@@ -15,13 +16,11 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class Dataset(Dataset):
-    def __init__(self, data_path, depth_path, transform):
-        self.data_path = data_path
-        self.depth_path = depth_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
+        self.depth_path = os.path.join(config['save_path'], 'depth', self.data_path.split('/')[-1])
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)
@@ -43,7 +42,7 @@ class Dataset(Dataset):
         depth = self.transform(depth)
         depth = torch.tensor(np.array(depth)).float() / 255.
         img_size = image.shape
-        return {"image": image, "height": img_size[1], "width": img_size[2],
+        return {"image": image.half(), "height": img_size[1], "width": img_size[2],
                 "true_height": true_img_size[0], "true_width": true_img_size[1],
                 'image_path': image_path, 'depth': depth}

prismer/experts/ocr_detection/generate_dataset.py
CHANGED
@@ -14,12 +14,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class Dataset(Dataset):
-    def __init__(self, data_path, transform):
-        self.data_path = data_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)
@@ -30,7 +28,7 @@ class Dataset(Dataset):
 
         image, scale_w, scale_h, original_w, original_h = resize(original_image)
         image = self.transform(image)
-        return image, image_path, scale_w, scale_h, original_w, original_h
+        return image.half(), image_path, scale_w, scale_h, original_w, original_h
 
 
 def resize(im):

prismer/experts/segmentation/generate_dataset.py
CHANGED
@@ -16,12 +16,10 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
 
 
 class Dataset(Dataset):
-    def __init__(self, data_path, transform):
-        self.data_path = data_path
+    def __init__(self, config, transform):
+        self.data_path = config['data_path']
         self.transform = transform
-        self.data_list = [data for f in data_folders for data in glob.glob(f + '*.JPEG')]
-        self.data_list += [data for f in data_folders for data in glob.glob(f + '*.jpg')]
+        self.data_list = [f'helpers/images/{config["im_name"]}.jpg']
 
     def __len__(self):
         return len(self.data_list)

prismer/helpers/images/COCO_test2015_000000000014.jpg
DELETED
Binary file (169 kB)

prismer/helpers/images/COCO_test2015_000000000016.jpg
DELETED
Binary file (231 kB)

prismer/helpers/images/COCO_test2015_000000000019.jpg
DELETED
Binary file (285 kB)

prismer/helpers/images/COCO_test2015_000000000128.jpg
DELETED
Binary file (212 kB)

prismer/helpers/images/COCO_test2015_000000000155.jpg
DELETED
Binary file (79.7 kB)

prismer/helpers/intro.png
DELETED
Binary file (405 kB)

prismer/model/prismer.py
CHANGED
@@ -5,6 +5,7 @@
 # https://github.com/NVlabs/prismer/blob/main/LICENSE
 
 import json
+import pathlib
 import torch.nn as nn
 
 from model.modules.vit import load_encoder
@@ -12,6 +13,9 @@ from model.modules.roberta import load_decoder
 from transformers import RobertaTokenizer, RobertaConfig
 
 
+cur_dir = pathlib.Path(__file__).parent
+
+
 class Prismer(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -26,7 +30,7 @@ class Prismer(nn.Module):
             elif exp in ['obj_detection', 'ocr_detection']:
                 self.experts[exp] = 64
 
-        prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']]
+        prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']]
         roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model'])
 
         self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name'])
@@ -35,7 +39,7 @@ class Prismer(nn.Module):
 
         self.prepare_to_train(config['freeze'])
         self.ignored_modules = self.get_ignored_modules(config['freeze'])
-
+
     def prepare_to_train(self, mode='none'):
         for name, params in self.named_parameters():
             if mode == 'freeze_lang':

prismer_model.py
CHANGED
@@ -7,6 +7,12 @@ import shlex
 import shutil
 import subprocess
 import sys
+import hashlib
+from typing import Tuple
+try:
+    import ruamel_yaml as yaml
+except ModuleNotFoundError:
+    import ruamel.yaml as yaml
 
 import cv2
 import torch
@@ -55,27 +61,43 @@ def run_expert(expert_name: str):
                    check=True)
 
 
-    image_dir.mkdir(parents=True, exist_ok=True)
-    out_path = image_dir / 'image.jpg'
-    cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
-
-    # expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
-    # run_expert('depth')
-    # with concurrent.futures.ProcessPoolExecutor() as executor:
-    #     executor.map(run_expert, expert_names)
-
-    # no parallelization just to be safe
-    expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
-    for exp in expert_names:
-        run_expert(exp)
-
-    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
+def compute_md5(image_path: str) -> str:
+    with open(image_path, 'rb') as f:
+        s = f.read()
+    return hashlib.md5(s).hexdigest()
+
+
+def run_experts(image_path: str) -> Tuple[str, Tuple[str, ...]]:
+    im_name = compute_md5(image_path)
+    out_path = submodule_dir / 'helpers' / 'images' / f'{im_name}.jpg'
+    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
+    results = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}.png' for key in keys]
+    results_pretty = [pathlib.Path('prismer/helpers/labels') / key / f'helpers/images/{im_name}_p.png' for key in keys]
+    out_paths = tuple(path.as_posix() for path in results)
+    pretty_paths = tuple(path.as_posix() for path in results_pretty)
+
+    config = yaml.load(open('prismer/configs/experts.yaml', 'r'), Loader=yaml.Loader)
+    config['im_name'] = im_name
+    with open('prismer/configs/experts.yaml', 'w') as yaml_file:
+        yaml.dump(config, yaml_file, default_flow_style=False)
+
+    if not os.path.exists(out_paths[0]):
+        cv2.imwrite(out_path.as_posix(), cv2.imread(image_path))
+
+        # paralleled inference
+        expert_names = ['edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
+        run_expert('depth')
+        with concurrent.futures.ProcessPoolExecutor() as executor:
+            executor.map(run_expert, expert_names)
+            executor.shutdown(wait=True)
+
+        # no parallelization just to be safe
+        # expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
+        # for exp in expert_names:
+        #     run_expert(exp)
+
+    label_prettify(image_path, out_paths)
+    return im_name, pretty_paths
 
 
 class Model:
@@ -126,20 +148,28 @@ class Model:
                                               len(model.expert_encoder.positional_embedding))
 
         model.load_state_dict(state_dict)
+        model = model.half()
         model.eval()
 
         self.config = config
-        self.model = model
+        self.model = model.to('cuda:0')
         self.tokenizer = model.tokenizer
         self.exp_name = exp_name
         self.mode = mode
 
     @torch.inference_mode()
-    def run_caption_model(self, exp_name: str) -> str:
+    def run_caption_model(self, exp_name: str, im_name: str) -> str:
        self.set_model(exp_name, 'caption')
+        self.config['im_name'] = im_name
         _, test_dataset = create_dataset('caption', self.config)
         test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
         experts, _ = next(iter(test_loader))
+        for exp in experts:
+            if exp == 'obj_detection':
+                experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
+                experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
+            else:
+                experts[exp] = experts[exp].to('cuda:0')
         captions = self.model(experts, train=False, prefix=self.config['prefix'])
         captions = self.tokenizer(captions, max_length=30, padding='max_length', return_tensors='pt').input_ids
         caption = captions.to(experts['rgb'].device)[0]
@@ -148,17 +178,23 @@ class Model:
         return caption
 
     def run_caption(self, image_path: str, model_name: str) -> tuple[str | None, ...]:
-        caption = self.run_caption_model(model_name)
-        return caption, *out_paths
+        im_name, pretty_paths = run_experts(image_path)
+        caption = self.run_caption_model(model_name, im_name)
+        return caption, *pretty_paths
 
     @torch.inference_mode()
-    def run_vqa_model(self, exp_name: str, question: str) -> str:
+    def run_vqa_model(self, exp_name: str, im_name: str, question: str) -> str:
         self.set_model(exp_name, 'vqa')
+        self.config['im_name'] = im_name
         _, test_dataset = create_dataset('caption', self.config)
         test_loader = create_loader(test_dataset, batch_size=1, num_workers=4, train=False)
         experts, _ = next(iter(test_loader))
+        for exp in experts:
+            if exp == 'obj_detection':
+                experts[exp]['label'] = experts['obj_detection']['label'].to('cuda:0')
+                experts[exp]['instance'] = experts['obj_detection']['instance'].to('cuda:0')
+            else:
+                experts[exp] = experts[exp].to('cuda:0')
         question = pre_question(question)
         answer = self.model(experts, [question], train=False, inference='generate')
         answer = self.tokenizer(answer, max_length=30, padding='max_length', return_tensors='pt').input_ids
@@ -168,7 +204,6 @@ class Model:
         return answer
 
     def run_vqa(self, image_path: str, model_name: str, question: str) -> tuple[str | None, ...]:
-        answer = self.run_vqa_model(model_name, question)
-        return answer, *out_paths
+        im_name, pretty_paths = run_experts(image_path)
+        answer = self.run_vqa_model(model_name, im_name, question)
+        return answer, *pretty_paths

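End to end, run_caption and run_vqa now hash the uploaded image, reuse any cached expert labels, and run the half-precision model on cuda:0. A usage sketch of the updated entry point, assuming a CUDA device, downloaded checkpoints, and that the Gradio apps construct Model() with no arguments as they appear to (the image path here is only an example):

from prismer_model import Model

model = Model()
# run_caption() md5-hashes the image, runs the experts only if their labels are not cached,
# and returns the caption followed by the prettified expert label paths.
caption, *label_previews = model.run_caption('prismer/images/example.jpg', 'Prismer-Base')
print(caption)
print(label_previews)  # paths to the <md5>_p.png files written by label_prettify
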
requirements.txt
CHANGED
@@ -6,7 +6,7 @@ fire==0.5.0
 geffnet==1.0.2
 git+https://github.com/facebookresearch/detectron2.git@5aeb252b194b93dc2879b4ac34bc51a31b5aee13
 git+https://github.com/openai/CLIP.git@a9b1bf5
-gradio==3.
+gradio==3.24.1
 huggingface-hub==0.12.1
 opencv-python-headless==4.7.0.72
 pyclipper==1.3.0.post4