"""Multimodal inference demo for the SEED-LLaMA tokenizer / LLM configs.

Running this script performs three generations in sequence:

1. visual question answering on ``images/cat.jpg``,
2. text-to-image generation (result saved to ``dog.jpg``),
3. image generation from an image + instruction prompt (result saved to
   ``cat_sunglasses.jpg``).

The tokenizer, image transform and model are instantiated with hydra from the
YAML files under ``configs/``.
"""
import hydra
import pyrootutils
import os
import torch
from omegaconf import OmegaConf
import json
from typing import Optional
import transformers
from PIL import Image
from torchvision.transforms.functional import InterpolationMode

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# Special tokens that delimit / spell out an image inside the text stream.
# NOTE(review): these literals were empty strings in the file as received
# (angle-bracket content appears to have been stripped). With empty strings,
# `tokenizer(BOI_TOKEN).input_ids[0]` fails and `IMG_TOKEN.format(...)` yields
# an empty image prompt. The values below are the restored markers - confirm
# they match the special tokens registered in the tokenizer's vocabulary.
BOI_TOKEN = '<img>'
EOI_TOKEN = '</img>'
IMG_TOKEN = '<img_{:05d}>'

IMG_FLAG = '<image>'  # not used in this script; restored alongside the tokens above
NUM_IMG_TOKNES = 32  # (sic) name kept as-is; not used in this script
NUM_IMG_CODES = 8192  # not used in this script
# Subtracted from generated token ids to turn them into image code indices.
image_id_shift = 32000


def generate(tokenizer, input_tokens, generation_config, model):
    """Generate a continuation for ``input_tokens`` and return only the new ids.

    Args:
        tokenizer: Callable tokenizer; invoked with ``add_special_tokens=False``
            and ``return_tensors='pt'``, and expected to expose ``.input_ids``.
        input_tokens: Prompt string (special tokens already spelled out).
        generation_config: Mapping of keyword arguments forwarded to
            ``model.generate``.
        model: Model exposing ``generate(input_ids=..., **kwargs)``.

    Returns:
        1-D tensor: the first generated sequence with its leading
        ``input_ids.shape[1]`` positions (the prompt) removed.
    """
    input_ids = tokenizer(
        input_tokens, add_special_tokens=False, return_tensors='pt'
    ).input_ids
    # Follow the model's own device instead of hard-coding "cuda", so the
    # helper keeps working if the model is placed elsewhere.
    input_ids = input_ids.to(next(model.parameters()).device)

    with torch.no_grad():
        generate_ids = model.generate(input_ids=input_ids, **generation_config)

    return generate_ids[0][input_ids.shape[1]:]


def decode_image_text(generate_ids, tokenizer, save_path=None):
    """Print the generated text and, if an image was generated, save it.

    The ids before the first begin-of-image token are decoded as text and
    printed. The ids between that token and the next end-of-image token are
    shifted by ``image_id_shift``, decoded with ``tokenizer.decode_image`` and
    the first resulting image is written to ``save_path``.

    Args:
        generate_ids: 1-D tensor of generated token ids.
        tokenizer: Tokenizer providing ``decode`` and ``decode_image``.
        save_path: Destination for the decoded image. When ``None`` the image
            (if any) is not decoded or saved.
    """
    boi_id = tokenizer(BOI_TOKEN, add_special_tokens=False).input_ids[0]
    eoi_id = tokenizer(EOI_TOKEN, add_special_tokens=False).input_ids[0]

    boi_positions = torch.where(generate_ids == boi_id)[0]
    if len(boi_positions) == 0:
        # No image was started: the whole output is text. This also covers a
        # stray end-of-image token without a matching begin token.
        print(tokenizer.decode(generate_ids, skip_special_tokens=True))
        return

    boi_index = int(boi_positions[0])
    text_ids = generate_ids[:boi_index]
    if len(text_ids) != 0:
        print(tokenizer.decode(text_ids, skip_special_tokens=True))

    # Only an end token located after the begin token can close the image.
    eoi_offsets = torch.where(generate_ids[boi_index + 1:] == eoi_id)[0]
    if len(eoi_offsets) == 0:
        # E.g. generation stopped at max_new_tokens in the middle of an image.
        print("Generated image is incomplete (no end-of-image token); skipping image decoding.")
        return

    eoi_index = boi_index + 1 + int(eoi_offsets[0])
    image_token_ids = generate_ids[boi_index + 1:eoi_index]
    if image_token_ids.numel() == 0:
        print("Generated image span is empty; skipping image decoding.")
        return

    if save_path is None:
        print("An image was generated but no save_path was given; skipping image decoding.")
        return

    image_ids = (image_token_ids - image_id_shift).reshape(1, -1)
    images = tokenizer.decode_image(image_ids)
    images[0].save(save_path)


device = "cuda"

tokenizer_cfg_path = 'configs/tokenizer/seed_llama_tokenizer.yaml'
tokenizer_cfg = OmegaConf.load(tokenizer_cfg_path)
tokenizer = hydra.utils.instantiate(tokenizer_cfg, device=device, load_diffusion=True)

transform_cfg_path = 'configs/transform/clip_transform.yaml'
transform_cfg = OmegaConf.load(transform_cfg_path)
transform = hydra.utils.instantiate(transform_cfg)

model_cfg = OmegaConf.load('configs/llm/seed_llama_8b.yaml')
model = hydra.utils.instantiate(model_cfg, torch_dtype=torch.float16)
model = model.eval().to(device)

generation_config = {
    'temperature': 1.0,
    'num_beams': 1,
    'max_new_tokens': 512,
    'top_p': 0.5,
    'do_sample': True,
}

# Conversation markers used to assemble each prompt below.
s_token = "USER:"
e_token = "ASSISTANT:"
sep = "\n"


### visual question answering
image_path = "images/cat.jpg"
image = Image.open(image_path).convert('RGB')
image_tensor = transform(image).to(device)
img_ids = tokenizer.encode_image(image_torch=image_tensor)
img_ids = img_ids.view(-1).cpu().numpy()
# Spell every image code out as its own token, wrapped in the image delimiters.
img_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(item) for item in img_ids]) + EOI_TOKEN

question = "What is this animal?"

input_tokens = tokenizer.bos_token + s_token + " " + img_tokens + question + sep + e_token
generate_ids = generate(tokenizer, input_tokens, generation_config, model)
decode_image_text(generate_ids, tokenizer)

### text-to-image generation
prompt = "Can you generate an image of a dog on the green grass?"
input_tokens = tokenizer.bos_token + s_token + " " + prompt + sep + e_token
generate_ids = generate(tokenizer, input_tokens, generation_config, model)
save_path = 'dog.jpg'
decode_image_text(generate_ids, tokenizer, save_path)

### multimodal prompt image generation
instruction = "Can you make the cat wear sunglasses?"
input_tokens = tokenizer.bos_token + s_token + " " + img_tokens + instruction + sep + e_token
generate_ids = generate(tokenizer, input_tokens, generation_config, model)
save_path = 'cat_sunglasses.jpg'
decode_image_text(generate_ids, tokenizer, save_path)