Spaces:
Sleeping
Sleeping
import os | |
import io | |
import torch | |
import PIL | |
from PIL import Image | |
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration | |
import bitsandbytes | |
import accelerate | |
from my_model.config import captioning_config as config | |
from my_model.utilities.gen_utilities import free_gpu_resources | |
class ImageCaptioningModel:
    """Generate image captions with an InstructBLIP model configured by ``captioning_config``.

    All settings are copied from ``my_model.config.captioning_config`` at
    construction time. The heavy processor/model objects are only created by
    :meth:`load_model`, which :meth:`generate_caption` calls lazily if needed.
    """

    def __init__(self):
        """Copy captioning settings from the config module; no model is loaded yet."""
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self):
        """Backward-compatible alias for the previously misspelled attribute name."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value):
        self.skip_special_tokens = value

    def load_model(self):
        """Load the processor and model for ``self.model_type``.

        Raises:
            ValueError: if ``self.model_type`` is not supported (only ``'i_blip'`` is).
        """
        if self.load_in_4bit and self.load_in_8bit:
            # Both quantization modes requested by mistake: keep 8-bit only.
            self.load_in_4bit = False
        if self.model_type != 'i_blip':
            # Fail fast instead of leaving processor/model as None, which would
            # only surface later as an opaque "'NoneType' is not callable".
            raise ValueError(f"Unsupported captioning model type: {self.model_type!r}")
        # The processor only preprocesses images/text; quantization, dtype and
        # device placement apply to the model weights, not to the processor.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )

    @staticmethod
    def _scaled_size(width, height, max_image_size):
        """Return the ``(width, height)`` that fits within ``max_image_size``, or ``None``.

        The aspect ratio is preserved and each side is at least 1 pixel. ``None``
        means the longest side already fits and no resize is needed; images are
        never upscaled.
        """
        longest = max(width, height)
        if longest <= max_image_size:
            return None
        # Multiply before dividing so the longest side lands exactly on the limit.
        new_width = max(1, int(width * max_image_size // longest))
        new_height = max(1, int(height * max_image_size // longest))
        return new_width, new_height

    def _default_max_image_size(self):
        """Resolve the size limit: ``MAX_IMAGE_SIZE`` env var, then config value, then 1024."""
        env_value = os.getenv("MAX_IMAGE_SIZE")
        if env_value is not None:
            return int(env_value)
        configured = getattr(self, "max_image_size", None)
        return 1024 if configured is None else int(configured)

    def resize_image(self, image, max_image_size=None):
        """Downscale ``image`` so its longest side is at most ``max_image_size`` pixels.

        Args:
            image: A PIL image.
            max_image_size: Maximum length of the longest side. When ``None``, the
                ``MAX_IMAGE_SIZE`` environment variable is used if set, otherwise
                the configured ``self.max_image_size``, otherwise 1024.

        Returns:
            A LANCZOS-resampled copy with the aspect ratio preserved, or ``image``
            itself when it already fits.
        """
        if max_image_size is None:
            max_image_size = self._default_max_image_size()
        # PIL reports Image.size as (width, height).
        width, height = image.size
        new_size = self._scaled_size(width, height, max_image_size)
        if new_size is None:
            return image
        return image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)

    def generate_caption(self, image_path):
        """Return the caption generated for one image.

        Args:
            image_path: A file path (``str`` or ``os.PathLike``), a binary
                file-like object, or an already opened ``PIL.Image.Image``.

        Raises:
            TypeError: if ``image_path`` is none of the supported types.
        """
        if isinstance(image_path, (str, os.PathLike, io.IOBase)):
            image = Image.open(image_path)
        elif isinstance(image_path, Image.Image):
            image = image_path
        else:
            raise TypeError(
                "image_path must be a file path, a file-like object or a PIL.Image.Image, "
                f"got {type(image_path).__name__}"
            )
        if self.model is None or self.processor is None:
            self.load_model()
        image = self.resize_image(image)
        # Send inputs to wherever the model actually lives instead of assuming CUDA.
        inputs = self.processor(image, self.prompt, return_tensors="pt").to(
            self.model.device, self.torch_dtype
        )
        outputs = self.model.generate(
            **inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens
        )
        return self.processor.decode(
            outputs[0], skip_special_tokens=self.skip_special_tokens
        ).strip()

    def generate_captions_for_multiple_images(self, image_paths):
        """Return one caption per input, in order (see :meth:`generate_caption`)."""
        return [self.generate_caption(image_path) for image_path in image_paths]
def get_caption(img):
    """Caption a single image using a newly constructed and loaded model.

    Args:
        img: Anything ``ImageCaptioningModel.generate_caption`` accepts
            (a file path, a file-like object, or a PIL image).

    Returns:
        The caption text produced by ``generate_caption``.
    """
    # A fresh model instance is built and loaded on every call.
    model = ImageCaptioningModel()
    model.load_model()
    return model.generate_caption(img)
if __name__ == "__main__":
    # No command-line behaviour: running this file directly does nothing;
    # the module is intended to be imported.
    ...