"""Inference script for TinyLLaVA-OpenELM: generate a response for a text prompt
and an optional image using an HF-converted checkpoint."""

import argparse
import logging
import os
import time
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer

from modeling_tinyllava_elm import TinyLlavaForConditionalGeneration
from configuration import *
from conversion import *
from utils import *


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB."""
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def generate(
    prompt: str,
    model: str,
    tokenizer=None,
    image: str = None,
    device: str = None,
    max_new_tokens: int = 1024,
    num_beams=1,
    top_p=None,
    temperature=0.2,
):
    """Generate a response for `prompt`, optionally conditioned on `image`.

    `model` may be a checkpoint path or an already-loaded
    TinyLlavaForConditionalGeneration instance. Returns the decoded output text
    and the generation time in seconds.
    """
    # Pick an inference device if the caller did not specify one.
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0),
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu, '
                'expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    # Accept either a checkpoint path or an already-loaded model.
    if isinstance(model, str):
        checkpoint_path = model
        model = TinyLlavaForConditionalGeneration.from_pretrained(
            checkpoint_path,
            torch_dtype=torch.float16,
        )
    else:
        # Fall back to the path recorded on the config when a loaded model is passed.
        checkpoint_path = model.config.name_or_path

    config = model.config
    if tokenizer is None:
        # Load the tokenizer from the same checkpoint directory as the model weights.
        tokenizer = AutoTokenizer.from_pretrained(
            checkpoint_path,
            use_fast=False,
            model_max_length=config.tokenizer_model_max_length,
            padding_side=config.tokenizer_padding_side,
        )
    image_processor = model.vision_tower._image_processor
    context_len = getattr(config, 'max_sequence_length', 2048)
    model.to(device).eval()

    # Build the conversation prompt, prepending the image token when an image is given.
    if image is not None:
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Preprocess the image (if any) into a tensor on the model's device;
    # text-only prompts pass images=None to `model.generate`.
    image_tensor = None
    if image is not None:
        image = load_image(image)
        image_tensor = process_images(image, image_processor, config).to(
            model.device, dtype=torch.float16
        )

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(device)
    )

    # Run generation and time it. No custom stopping criteria are used;
    # decoding stops at EOS or after max_new_tokens.
    stime = time.time()
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
        )
    generation_time = time.time() - stime

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
    outputs = outputs.strip()

    return outputs, generation_time


def tinyllava_elm_generate_parser():
    """Argument Parser"""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of form key=value"""

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        'Argument parsing error, kwargs are expected in'
                        ' the form of key=value.'
                    )
                kwarg_k, kwarg_v = val.split('=')
                # Coerce the value to int, then float, falling back to str.
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('TinyLLaVA-OpenELM Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the hf converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument("--image", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    return parser.parse_args()


if __name__ == '__main__':
    args = tinyllava_elm_generate_parser()

    # generate() loads the checkpoint itself, so pass the path rather than
    # loading the model a second time here.
    output_text, generation_time = generate(
        prompt=args.prompt,
        image=args.image,
        model=args.model,
        device=args.device,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
        top_p=args.top_p,
        temperature=args.temperature,
    )

    print_txt = (
        f'\r\n{"=" * os.get_terminal_size().columns}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        f'{output_text}\r\n'
        f'{"-" * os.get_terminal_size().columns}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
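
# Example usage (illustrative only): the file name, checkpoint path, and image
# URL below are placeholders and not part of this repo; only the flags defined
# in tinyllava_elm_generate_parser() are real.
#
#   python generate.py \
#       --model /path/to/hf-converted-tinyllava-openelm \
#       --image https://example.com/sample.jpg \
#       --prompt "What is shown in the image?" \
#       --temperature 0.2 \
#       --max_new_tokens 256
#
# The same entry point can be called programmatically, e.g.
# generate(prompt="Describe the image.", model="/path/to/checkpoint", image=None),
# which returns the decoded text and the generation time in seconds.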