Spaces:
Sleeping
Sleeping
diff --git a/dataset/caption_dataset.py b/dataset/caption_dataset.py | |
index 266fdda..0cc5d3f 100644 | |
--- a/dataset/caption_dataset.py | |
+++ b/dataset/caption_dataset.py | |
class Caption(Dataset): | |
elif self.dataset == 'demo': | |
img_path_split = self.data_list[index]['image'].split('/') | |
img_name = img_path_split[-2] + '/' + img_path_split[-1] | |
- image, labels, labels_info = get_expert_labels('', self.label_path, img_name, 'helpers', self.experts) | |
+ image, labels, labels_info = get_expert_labels('prismer', self.label_path, img_name, 'helpers', self.experts) | |
experts = self.transform(image, labels) | |
experts = post_label_process(experts, labels_info) | |
diff --git a/dataset/utils.py b/dataset/utils.py | |
index b368aac..418358c 100644 | |
--- a/dataset/utils.py | |
+++ b/dataset/utils.py | |
# https://github.com/NVlabs/prismer/blob/main/LICENSE | |
import os | |
+import pathlib | |
import re | |
import json | |
import torch | |
import torchvision.transforms as transforms | |
import torchvision.transforms.functional as transforms_f | |
from dataset.randaugment import RandAugment | |
-COCO_FEATURES = torch.load('dataset/coco_features.pt')['features'] | |
-ADE_FEATURES = torch.load('dataset/ade_features.pt')['features'] | |
-DETECTION_FEATURES = torch.load('dataset/detection_features.pt')['features'] | |
-BACKGROUND_FEATURES = torch.load('dataset/background_features.pt') | |
+cur_dir = pathlib.Path(__file__).parent | |
+ | |
+COCO_FEATURES = torch.load(cur_dir / 'coco_features.pt')['features'] | |
+ADE_FEATURES = torch.load(cur_dir / 'ade_features.pt')['features'] | |
+DETECTION_FEATURES = torch.load(cur_dir / 'detection_features.pt')['features'] | |
+BACKGROUND_FEATURES = torch.load(cur_dir / 'background_features.pt') | |
class Transform: | |
diff --git a/model/prismer.py b/model/prismer.py | |
index 080253a..02362a4 100644 | |
--- a/model/prismer.py | |
+++ b/model/prismer.py | |
# https://github.com/NVlabs/prismer/blob/main/LICENSE | |
import json | |
+import pathlib | |
import torch.nn as nn | |
from model.modules.vit import load_encoder | |
from model.modules.roberta import load_decoder | |
from transformers import RobertaTokenizer, RobertaConfig | |
+cur_dir = pathlib.Path(__file__).parent | |
+ | |
+ | |
class Prismer(nn.Module): | |
def __init__(self, config): | |
super().__init__() | |
class Prismer(nn.Module): | |
elif exp in ['obj_detection', 'ocr_detection']: | |
self.experts[exp] = 64 | |
- prismer_config = json.load(open('configs/prismer.json', 'r'))[config['prismer_model']] | |
+ prismer_config = json.load(open(f'{cur_dir.parent}/configs/prismer.json', 'r'))[config['prismer_model']] | |
roberta_config = RobertaConfig.from_dict(prismer_config['roberta_model']) | |
self.tokenizer = RobertaTokenizer.from_pretrained(prismer_config['roberta_model']['model_name']) | |
class Prismer(nn.Module): | |
self.prepare_to_train(config['freeze']) | |
self.ignored_modules = self.get_ignored_modules(config['freeze']) | |
- | |
+ | |
def prepare_to_train(self, mode='none'): | |
for name, params in self.named_parameters(): | |
if mode == 'freeze_lang': | |