# prismer / prismer_model.py
# Last commit by shikunl: 818a4f8 ("Reset"), 4.77 kB.
from __future__ import annotations
import os
import pathlib
import shlex
import shutil
import subprocess
import sys
import cv2
import torch
# Directory holding this module; the Prismer code base lives in a
# ``prismer`` sub-directory next to it.
repo_dir = pathlib.Path(__file__).parent
submodule_dir = repo_dir.joinpath('prismer')
# Put the submodule first on the import path so its top-level packages
# (``dataset``, ``model``) can be imported directly by the lines below.
sys.path.insert(0, submodule_dir.as_posix())
from dataset import create_dataset, create_loader
from model.prismer_caption import PrismerCaption
def download_models() -> None:
    """Fetch the expert weights and caption checkpoints that are missing.

    Runs ``download_checkpoints.py`` inside ``prismer/`` once for the expert
    weights (skipped when ``prismer/experts/expert_weights/`` exists) and
    once per caption model whose ``prismer/logging/<name>`` directory is
    absent. All paths are relative to the current working directory, so
    this expects to be run from the directory that contains ``prismer/``.

    Raises:
        subprocess.CalledProcessError: If a download script exits non-zero.
            Failing here is clearer than a later missing-checkpoint error.
    """
    if not pathlib.Path('prismer/experts/expert_weights/').exists():
        # sys.executable runs the script with the current interpreter rather
        # than whichever ``python`` happens to be first on PATH.
        subprocess.run(
            [sys.executable, 'download_checkpoints.py', '--download_experts=True'],
            cwd='prismer',
            check=True)

    model_names = [
        # VQA checkpoints are disabled; only caption models are loaded by
        # ``Model`` in this module.
        # 'vqa_prismer_base',
        # 'vqa_prismer_large',
        'caption_prismer_base',
        'caption_prismer_large',
    ]
    for model_name in model_names:
        if pathlib.Path(f'prismer/logging/{model_name}').exists():
            continue
        subprocess.run(
            [sys.executable, 'download_checkpoints.py', f'--download_models={model_name}'],
            cwd='prismer',
            check=True)
def build_deformable_conv() -> None:
    """Build the native op in Mask2Former's ``pixel_decoder/ops`` directory.

    Runs that directory's ``make.sh`` with ``sh``. The path is relative to
    the current working directory.

    Raises:
        subprocess.CalledProcessError: If the build script exits non-zero,
            rather than silently leaving the extension unbuilt.
    """
    subprocess.run(
        ['sh', 'make.sh'],
        cwd='prismer/experts/segmentation/mask2former/modeling/pixel_decoder/ops',
        check=True)
def run_experts(image_path: str) -> tuple[str | None, ...]:
    """Run every Prismer expert script on one image and collect its outputs.

    The ``helpers`` directory under the submodule is deleted and recreated,
    the image is re-encoded to ``helpers/images/image.jpg``, and each
    ``experts/generate_<name>.py`` script is run in a subprocess (cwd
    ``prismer``) with the submodule prepended to ``PYTHONPATH``.

    Args:
        image_path: Path of the input image; anything ``cv2.imread`` decodes.

    Returns:
        One entry per label key, in the order depth, edge, normal, seg_coco,
        obj_detection, ocr_detection: the POSIX path of the generated label
        image, or ``None`` if no such file was produced.

    Raises:
        ValueError: If ``image_path`` cannot be read as an image.
        subprocess.CalledProcessError: If any expert script exits non-zero.
    """
    # cv2.imread signals failure by returning None instead of raising; passing
    # None on to cv2.imwrite would fail with an opaque assertion. Read first so
    # a bad path fails before any files are touched.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f'Could not read image: {image_path}')

    helper_dir = submodule_dir / 'helpers'
    # Start from a clean directory so label files from a previous image cannot
    # be mistaken for this run's results by the existence checks below.
    shutil.rmtree(helper_dir, ignore_errors=True)
    image_dir = helper_dir / 'images'
    image_dir.mkdir(parents=True, exist_ok=True)
    cv2.imwrite((image_dir / 'image.jpg').as_posix(), image)

    # The environment is identical for every expert, so build it once.
    env = os.environ.copy()
    existing_path = env.get('PYTHONPATH')
    env['PYTHONPATH'] = (
        os.pathsep.join([submodule_dir.as_posix(), existing_path])
        if existing_path else submodule_dir.as_posix())

    expert_names = ['depth', 'edge', 'normal', 'objdet', 'ocrdet', 'segmentation']
    for expert_name in expert_names:
        subprocess.run(
            [sys.executable, f'experts/generate_{expert_name}.py'],
            cwd='prismer',
            env=env,
            check=True)

    # Each label mirrors the input's location (helpers/images/) under
    # labels/<key>/ with a .png extension; presumably this is how the
    # generate_* scripts name their outputs. TODO confirm against them.
    label_dir = pathlib.Path('prismer/helpers/labels')
    keys = ['depth', 'edge', 'normal', 'seg_coco', 'obj_detection', 'ocr_detection']
    results = [label_dir / key / 'helpers/images/image.png' for key in keys]
    return tuple(path.as_posix() if path.exists() else None for path in results)
class Model:
def __init__(self):
self.config = None
self.model = None
self.tokenizer = None
self.exp_name = ''
def set_model(self, exp_name: str) -> None:
if exp_name == self.exp_name:
return
config = {
'dataset':
'demo',
'data_path':
'prismer/helpers',
'label_path':
'prismer/helpers/labels',
'experts': [
'depth',
'normal',
'seg_coco',
'edge',
'obj_detection',
'ocr_detection',
],
'image_resolution':
480,
'prismer_model':
'prismer_base',
'freeze':
'freeze_vision',
'prefix':
'A picture of',
}
model = PrismerCaption(config)
state_dict = torch.load(
f'prismer/logging/caption_{exp_name}/pytorch_model.bin',
map_location='cuda:0')
model.load_state_dict(state_dict)
model.eval()
tokenizer = model.tokenizer
self.config = config
self.model = model
self.tokenizer = tokenizer
self.exp_name = exp_name
@torch.inference_mode()
def run_caption_model(self, exp_name: str) -> str:
self.set_model(exp_name)
_, test_dataset = create_dataset('caption', self.config)
test_loader = create_loader(test_dataset,
batch_size=1,
num_workers=4,
train=False)
experts, _ = next(iter(test_loader))
captions = self.model(experts,
train=False,
prefix=self.config['prefix'])
captions = self.tokenizer(captions,
max_length=30,
padding='max_length',
return_tensors='pt').input_ids
caption = captions.to(experts['rgb'].device)[0]
caption = self.tokenizer.decode(caption, skip_special_tokens=True)
caption = caption.capitalize() + '.'
return caption
def run_caption(self, image_path: str,
model_name: str) -> tuple[str | None, ...]:
out_paths = run_experts(image_path)
caption = self.run_caption_model(model_name)
return caption, *out_paths