Spaces:
Sleeping
Sleeping
Add VQA
Browse files
- label_prettify.py +1 -1
- prismer_model.py +8 -6
label_prettify.py
CHANGED
@@ -87,7 +87,7 @@ def ocr_detection_prettify(rgb_path, file_name):
|
|
87 |
ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
|
88 |
|
89 |
plt.imshow(rgb)
|
90 |
-
plt.imshow(
|
91 |
|
92 |
for i in np.unique(ocr_labels)[:-1]:
|
93 |
text_idx_all = np.where(ocr_labels == i)
|
|
|
87 |
ocr_labels_dict = torch.load(file_name.replace('.png', '.pt'))
|
88 |
|
89 |
plt.imshow(rgb)
|
90 |
+
plt.imshow(ocr_labels, cmap='gray', alpha=0.8)
|
91 |
|
92 |
for i in np.unique(ocr_labels)[:-1]:
|
93 |
text_idx_all = np.where(ocr_labels == i)
|
prismer_model.py
CHANGED
@@ -75,11 +75,13 @@ class Model:
|
|
75 |
if exp_name == self.exp_name:
|
76 |
return
|
77 |
|
|
|
78 |
if self.exp_name == 'Prismer-Base':
|
79 |
-
|
80 |
elif self.exp_name == 'Prismer-Large':
|
81 |
-
|
82 |
|
|
|
83 |
if self.mode == 'caption':
|
84 |
config = {
|
85 |
'dataset': 'demo',
|
@@ -87,12 +89,12 @@ class Model:
|
|
87 |
'label_path': 'prismer/helpers/labels',
|
88 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
89 |
'image_resolution': 480,
|
90 |
-
'prismer_model':
|
91 |
'freeze': 'freeze_vision',
|
92 |
'prefix': '',
|
93 |
}
|
94 |
model = PrismerCaption(config)
|
95 |
-
state_dict = torch.load(f'prismer/logging/pretrain_{
|
96 |
|
97 |
elif self.mode == 'vqa':
|
98 |
config = {
|
@@ -101,12 +103,12 @@ class Model:
|
|
101 |
'label_path': 'prismer/helpers/labels',
|
102 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
103 |
'image_resolution': 480,
|
104 |
-
'prismer_model':
|
105 |
'freeze': 'freeze_vision',
|
106 |
}
|
107 |
|
108 |
model = PrismerVQA(config)
|
109 |
-
state_dict = torch.load(f'prismer/logging/vqa_{
|
110 |
|
111 |
model.load_state_dict(state_dict)
|
112 |
model.eval()
|
|
|
75 |
if exp_name == self.exp_name:
|
76 |
return
|
77 |
|
78 |
+
# remap model name
|
79 |
if self.exp_name == 'Prismer-Base':
|
80 |
+
self.exp_name = 'prismer_base'
|
81 |
elif self.exp_name == 'Prismer-Large':
|
82 |
+
self.exp_name = 'prismer_large'
|
83 |
|
84 |
+
# load checkpoints
|
85 |
if self.mode == 'caption':
|
86 |
config = {
|
87 |
'dataset': 'demo',
|
|
|
89 |
'label_path': 'prismer/helpers/labels',
|
90 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
91 |
'image_resolution': 480,
|
92 |
+
'prismer_model': self.exp_name,
|
93 |
'freeze': 'freeze_vision',
|
94 |
'prefix': '',
|
95 |
}
|
96 |
model = PrismerCaption(config)
|
97 |
+
state_dict = torch.load(f'prismer/logging/pretrain_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
|
98 |
|
99 |
elif self.mode == 'vqa':
|
100 |
config = {
|
|
|
103 |
'label_path': 'prismer/helpers/labels',
|
104 |
'experts': ['depth', 'normal', 'seg_coco', 'edge', 'obj_detection', 'ocr_detection'],
|
105 |
'image_resolution': 480,
|
106 |
+
'prismer_model': self.exp_name,
|
107 |
'freeze': 'freeze_vision',
|
108 |
}
|
109 |
|
110 |
model = PrismerVQA(config)
|
111 |
+
state_dict = torch.load(f'prismer/logging/vqa_{self.exp_name}/pytorch_model.bin', map_location='cuda:0')
|
112 |
|
113 |
model.load_state_dict(state_dict)
|
114 |
model.eval()
|