Spaces:
Sleeping
Sleeping
from __future__ import annotations | |
import os | |
import pathlib | |
import shlex | |
import shutil | |
import subprocess | |
import sys | |
import cv2 | |
import torch | |
# Directory holding this file; the Prismer submodule is checked out beneath it.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir.joinpath('prismer')
# Put the submodule at the front of the import path so its top-level
# packages (e.g. `dataset`, `model`) can be imported directly.
sys.path.insert(0, submodule_dir.as_posix())
from dataset import create_dataset, create_loader | |
from model.prismer_caption import PrismerCaption | |
def download_models() -> None:
    """Download any Prismer expert weights and caption checkpoints not on disk.

    Invokes the submodule's ``download_checkpoints.py`` with the running
    interpreter: once for the expert weights (skipped when
    ``experts/expert_weights`` already exists) and once per caption
    checkpoint (skipped when ``logging/<name>`` already exists).

    Raises:
        subprocess.CalledProcessError: if a download script exits non-zero.
    """
    # sys.executable guarantees the same interpreter (and installed packages)
    # as this process; a bare 'python' depends on whatever PATH resolves to.
    script = [sys.executable, 'download_checkpoints.py']
    # Resolve against the submodule directory rather than the process cwd so
    # the existence checks and the script's working directory always agree.
    if not (submodule_dir / 'experts' / 'expert_weights').exists():
        subprocess.run([*script, '--download_experts=True'],
                       cwd=submodule_dir,
                       check=True)
    # Only caption checkpoints are loaded by this module; the VQA variants
    # ('vqa_prismer_base', 'vqa_prismer_large') are not fetched.
    model_names = [
        'caption_prismer_base',
        'caption_prismer_large',
    ]
    for model_name in model_names:
        if (submodule_dir / 'logging' / model_name).exists():
            continue
        subprocess.run([*script, f'--download_models={model_name}'],
                       cwd=submodule_dir,
                       check=True)
def build_deformable_conv() -> None:
    """Build the native ops in the Mask2Former pixel-decoder ``ops`` directory.

    Runs ``sh make.sh`` inside
    ``experts/segmentation/mask2former/modeling/pixel_decoder/ops`` of the
    submodule.

    Raises:
        subprocess.CalledProcessError: if the build script exits non-zero.
    """
    ops_dir = (submodule_dir / 'experts' / 'segmentation' / 'mask2former' /
               'modeling' / 'pixel_decoder' / 'ops')
    # check=True makes a failed build surface here, rather than later
    # (presumably when the segmentation expert tries to use the ops).
    subprocess.run(['sh', 'make.sh'], cwd=ops_dir, check=True)
def run_experts(image_path: str) -> tuple[str | None, ...]:
    """Run every Prismer expert on one image and return its label images.

    The submodule's ``helpers`` directory is wiped and the input image is
    re-encoded to ``helpers/images/image.jpg``. Each
    ``experts/generate_<name>.py`` script is then run from the submodule
    directory with the submodule prepended to ``PYTHONPATH``.

    Args:
        image_path: Path of the input image; must be decodable by OpenCV.

    Returns:
        One entry per label kind, in the order depth, edge, normal, seg_coco,
        obj_detection, ocr_detection. Each entry is the POSIX path of the
        generated PNG, or ``None`` if that file does not exist.

    Raises:
        ValueError: if ``image_path`` cannot be read as an image.
        subprocess.CalledProcessError: if an expert script exits non-zero.
    """
    helper_dir = submodule_dir / 'helpers'
    shutil.rmtree(helper_dir, ignore_errors=True)
    image_dir = helper_dir / 'images'
    image_dir.mkdir(parents=True, exist_ok=True)

    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise ValueError(f'Could not read image: {image_path}')
    cv2.imwrite((image_dir / 'image.jpg').as_posix(), image)

    # The environment is the same for every expert, so build it once.
    # Use the platform path separator rather than a hard-coded ':'.
    env = os.environ.copy()
    existing = env.get('PYTHONPATH')
    env['PYTHONPATH'] = (submodule_dir.as_posix() if existing is None else
                         f'{submodule_dir.as_posix()}{os.pathsep}{existing}')

    expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
    for expert_name in expert_names:
        subprocess.run([sys.executable, f'experts/generate_{expert_name}.py'],
                       cwd=submodule_dir,
                       env=env,
                       check=True)

    # Look up outputs under the same helpers directory created above (not a
    # cwd-relative path), at labels/<key>/helpers/images/image.png --
    # presumably the input's relative path with a .png suffix.
    label_dir = helper_dir / 'labels'
    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
    results = [label_dir / key / 'helpers/images/image.png' for key in keys]
    return tuple(path.as_posix() if path.exists() else None
                 for path in results)
class Model:
    """Prismer captioning model, loaded lazily and cached by experiment name."""

    def __init__(self):
        self.config = None
        self.model = None
        self.tokenizer = None
        self.exp_name = ''  # name of the currently loaded checkpoint

    @staticmethod
    def _build_config(exp_name: str) -> dict:
        """Return the Prismer caption config for checkpoint ``exp_name``.

        ``exp_name`` is expected to look like ``'prismer_base'`` or
        ``'prismer_large'``; anything without ``'large'`` maps to the base
        backbone.
        """
        # The backbone must match the checkpoint's size. Hard-coding
        # 'prismer_base' makes the large checkpoint's weights mismatch the
        # architecture (load_state_dict is strict by default).
        backbone = 'prismer_large' if 'large' in exp_name else 'prismer_base'
        return {
            'dataset': 'demo',
            # NOTE(review): these paths are relative to the process cwd.
            'data_path': 'prismer/helpers',
            'label_path': 'prismer/helpers/labels',
            'experts': [
                'depth',
                'normal',
                'seg_coco',
                'edge',
                'obj_detection',
                'ocr_detection',
            ],
            'image_resolution': 480,
            'prismer_model': backbone,
            'freeze': 'freeze_vision',
            'prefix': 'A picture of',
        }

    def set_model(self, exp_name: str) -> None:
        """Load checkpoint ``caption_<exp_name>`` unless it is already loaded."""
        if exp_name == self.exp_name:
            return
        # Drop the previous model first so its weights can be freed before
        # the next ones are allocated. Clearing exp_name means a failed load
        # is retried on the next call instead of being treated as cached.
        self.model = None
        self.exp_name = ''
        config = self._build_config(exp_name)
        model = PrismerCaption(config)
        state_dict = torch.load(
            f'prismer/logging/caption_{exp_name}/pytorch_model.bin',
            map_location='cuda:0')
        model.load_state_dict(state_dict)
        model.eval()
        self.config = config
        self.model = model
        self.tokenizer = model.tokenizer
        self.exp_name = exp_name

    def run_caption_model(self, exp_name: str) -> str:
        """Caption the image staged under the config's ``data_path``.

        Returns:
            The first generated caption, capitalized and ending with '.'.
        """
        self.set_model(exp_name)
        _, test_dataset = create_dataset('caption', self.config)
        test_loader = create_loader(test_dataset,
                                    batch_size=1,
                                    num_workers=4,
                                    train=False)
        experts, _ = next(iter(test_loader))
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            captions = self.model(experts,
                                  train=False,
                                  prefix=self.config['prefix'])
        # Re-tokenize and decode the first caption with special tokens
        # removed. Decoding does not need the ids on the model's device.
        token_ids = self.tokenizer(captions,
                                   max_length=30,
                                   padding='max_length',
                                   return_tensors='pt').input_ids
        caption = self.tokenizer.decode(token_ids[0], skip_special_tokens=True)
        return caption.capitalize() + '.'

    def run_caption(self, image_path: str,
                    model_name: str) -> tuple[str | None, ...]:
        """Run all experts on ``image_path``, then caption it.

        Returns:
            ``(caption, *expert_output_paths)``, where the paths are those
            returned by ``run_experts``.
        """
        out_paths = run_experts(image_path)
        caption = self.run_caption_model(model_name)
        return caption, *out_paths