import hydra
import pyrootutils
import os
import torch
from omegaconf import OmegaConf
from flask import Flask, request
import json
from typing import Optional
import transformers
from dataclasses import dataclass, field
import io
import base64
from PIL import Image
import gc

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'
IMG_FLAG = '<image>'
NUM_IMG_TOKNES = 32
NUM_IMG_CODES = 8192
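# A note on the prompt format used below: each image is serialized as BOI_TOKEN,
# followed by a fixed number of discrete codes (NUM_IMG_TOKNES = 32) rendered with
# IMG_TOKEN (e.g. IMG_TOKEN.format(3) -> '<img_00003>'), and closed with EOI_TOKEN.
# In the LLM vocabulary these codes are offset by image_id_shift = 32000 (see
# LLMService below), so quantized image code 3 corresponds to LLM token id 32003.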

app = Flask(__name__)

def decode_image(encoded_image: str) -> Image.Image:
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image

@dataclass
class Arguments:
    image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    tokenizer_device: Optional[str] = field(default='cuda:0', metadata={"help": "tokenizer device"})
    offload_encoder: Optional[bool] = field(default=False, metadata={"help": "offload the visual encoder of the image tokenizer to CPU when idle"})
    offload_decoder: Optional[bool] = field(default=False, metadata={"help": "offload the diffusion decoder of the image tokenizer to CPU when idle"})


parser = transformers.HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses()
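# Example invocation (a sketch only; the script name and config paths are
# placeholders for whatever your checkout provides, HfArgumentParser exposes
# each dataclass field above as a --flag of the same name):
#   python <path-to-this-script> \
#       --image_transform <image_transform_config>.yaml \
#       --tokenizer <tokenizer_config>.yaml \
#       --model <llm_config>.yaml \
#       --port 80 \
#       --llm_device cuda:0 \
#       --tokenizer_device cuda:0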

class LLMService:

    def __init__(self, args) -> None:
        image_transform_cfg = OmegaConf.load(args.image_transform)
        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        model_cfg = OmegaConf.load(args.model)
        self.image_id_shift = 32000

        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        model = hydra.utils.instantiate(model_cfg, device_map=args.llm_device).eval()
        self.model = model
        print(model.get_memory_footprint())

        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=args.tokenizer_device, load_diffusion=True)
        if args.offload_encoder:
            self.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
        if args.offload_decoder:
            self.tokenizer.image_tokenizer.diffusion_model.to('cpu')

        # model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
        # self.model = model.eval().to(args.llm_device)

        self.llm_device = args.llm_device
        self.tokenizer_device = args.tokenizer_device
        self.offload_encoder = args.offload_encoder
        self.offload_decoder = args.offload_decoder

        self.boi_token_id = self.tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
        self.eoi_token_id = self.tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]
        print('Init Done...')


service = LLMService(args)

@app.route('/generate', methods=['GET', 'POST'])
def generate():
    request_info = request.get_json()

    text_list = request_info['text'].split(IMG_FLAG)
    image_list = request_info['images']
    temperature = request_info.get('temperature', 0.7)
    num_beams = request_info.get('num_beams', 1)
    max_new_tokens = request_info.get('max_new_tokens', 256)
    top_p = request_info.get('top_p', 0.5)
    force_boi = request_info.get('force_boi', False)

    assert len(text_list) == len(image_list) + 1
    if len(image_list) > 0:
        images_tensor_list = []
        images_tensor_indices = []
        images_ids_list = []
        images_ids_indices = []
        for idx, image_item in enumerate(image_list):
            if isinstance(image_item, str):
                image = decode_image(image_item)
                image_tensor = service.image_transform(image)
                images_tensor_list.append(image_tensor)
                images_tensor_indices.append(idx)
            else:
                images_ids_list.append(image_item)
                images_ids_indices.append(idx)

        if len(images_tensor_list) > 0:
            images_tensor = torch.stack(images_tensor_list, dim=0).to(service.tokenizer_device)
            if service.offload_encoder:
                service.tokenizer.image_tokenizer.model.visual_encoder.to(service.tokenizer_device)
            images_ids_1 = service.tokenizer.encode_image(image_torch=images_tensor).cpu()
            if args.offload_encoder:
                service.tokenizer.image_tokenizer.model.visual_encoder.to('cpu')
                torch.cuda.empty_cache()
                gc.collect()
            num_image_ids = images_ids_1.shape[-1]
        else:
            num_image_ids = len(images_ids_list[-1])

        images_ids_2 = torch.tensor(images_ids_list, dtype=torch.long)
        images_ids = torch.zeros((len(image_list), num_image_ids), dtype=torch.long)
        if len(images_tensor_indices) > 0:
            images_ids[images_tensor_indices, :] = images_ids_1
        if len(images_ids_indices) > 0:
            images_ids[images_ids_indices, :] = images_ids_2

        input_text = ''
        for i in range(images_ids.shape[0]):
            single_image_ids = images_ids[i].view(-1).tolist()
            image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in single_image_ids]) + EOI_TOKEN
            input_text += text_list[i] + image_tokens
        input_text = service.tokenizer.bos_token + input_text + text_list[-1]
        images_ids_list = images_ids.tolist()
    else:
        input_text = service.tokenizer.bos_token + ''.join(text_list)
        images_ids_list = []
    if force_boi:
        input_text += BOI_TOKEN

    print(input_text)
    input_ids = service.tokenizer(input_text, add_special_tokens=False, return_tensors='pt').input_ids
    input_ids = input_ids.to(service.llm_device)
    generation_config = {
        'temperature': temperature,
        'num_beams': num_beams,
        'max_new_tokens': max_new_tokens,
        'top_p': top_p,
        'do_sample': True
    }

    generate_ids = service.model.generate(input_ids=input_ids, **generation_config)
    if force_boi:
        generate_ids = generate_ids[0][input_ids.shape[1] - 1:]
    else:
        generate_ids = generate_ids[0][input_ids.shape[1]:]
    print('generated_ids: ', generate_ids)
    boi_indices = torch.where(generate_ids == service.boi_token_id)[0].tolist()
    eoi_indices = torch.where(generate_ids == service.eoi_token_id)[0].tolist()
    # assert len(boi_indices) == len(eoi_indices)

    generated_image_base64_list = []
    text_mask = torch.ones_like(generate_ids, dtype=torch.bool)

    error_msg = []
    if len(boi_indices) != len(eoi_indices):
        error_msg.append(
            f'Number of BOI (begin of image) tokens: {len(boi_indices)} does not match number of EOI (end of image) tokens: {len(eoi_indices)}; some images will fail to decode.'
        )
    num_images = min(len(boi_indices), len(eoi_indices))
    for idx in range(num_images):
        boi_index, eoi_index = boi_indices[idx], eoi_indices[idx]
        # for boi_index, eoi_index in zip(boi_indices, eoi_indices):
        image_ids = generate_ids[boi_index + 1:eoi_index].unsqueeze(0).to(service.tokenizer_device)
        image_ids = image_ids - service.image_id_shift
        if image_ids.shape[-1] != NUM_IMG_TOKNES:
            error_msg.append(f'len(image_ids) {image_ids.shape[-1]} is not equal to {NUM_IMG_TOKNES}')
            image_base64 = ''
        elif (image_ids < 0).any() or (image_ids >= NUM_IMG_CODES).any():
            error_msg.append(f'Some image_id is out of range: [0, {NUM_IMG_CODES})')
            image_base64 = ''
        else:
            if service.offload_decoder:
                service.tokenizer.image_tokenizer.diffusion_model.to(service.tokenizer_device)
            image = service.tokenizer.decode_image(image_ids)[0]
            if service.offload_decoder:
                service.tokenizer.image_tokenizer.diffusion_model.to('cpu')
                torch.cuda.empty_cache()
                gc.collect()
            image_base64 = encode_image(image)

        generated_image_base64_list.append(image_base64)
        text_mask[boi_index + 1:eoi_index] = False
        images_ids_list.append(image_ids.view(-1).tolist())

    generate_ids = generate_ids[text_mask]
    # print('generate_ids: ', generate_ids)
    # generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=True)
    generate_text = service.tokenizer.decode(generate_ids, skip_special_tokens=False)
    # print('generate_text before: ', generate_text)
    generate_text = generate_text.replace(BOI_TOKEN + ' ' + EOI_TOKEN + ' ', IMG_FLAG)
    generate_text = generate_text.replace(service.tokenizer.eos_token, '')
    print('generate_text: ', generate_text)

    return {'text': generate_text, 'images': generated_image_base64_list, 'images_ids': images_ids_list, 'error_msg': error_msg}

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=args.port)
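
# Example client call (a sketch, assuming the '/generate' route above and a server
# on localhost at the --port you launched with; the prompt and image path are placeholders):
#
#   import base64, requests
#
#   with open('example.jpg', 'rb') as f:
#       image_b64 = base64.b64encode(f.read()).decode('utf-8')
#
#   payload = {
#       'text': 'Describe this image: <image>',  # one IMG_FLAG per entry in 'images'
#       'images': [image_b64],                   # base64 strings or precomputed image-id lists
#       'temperature': 0.7,
#       'num_beams': 1,
#       'max_new_tokens': 256,
#       'top_p': 0.5,
#       'force_boi': False,
#   }
#   response = requests.post('http://127.0.0.1:80/generate', json=payload).json()
#   print(response['text'], response['error_msg'])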