|
import os |
|
import io |
|
import torch |
|
import PIL |
|
from PIL import Image |
|
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration |
|
import bitsandbytes |
|
import accelerate |
|
from my_model.config import captioning_config as config |
|
from my_model.utilities.gen_utilities import free_gpu_resources |
|
|
|
|
|
class ImageCaptioningModel:
    """Generate image captions with an InstructBLIP model.

    Every setting is read from ``my_model.config.captioning_config`` when the
    instance is created. No weights are loaded until :meth:`load_model` is
    called, and :meth:`generate_caption` requires that call to have happened.
    """

    def __init__(self):
        """Copy the captioning settings from the config module; load nothing yet."""
        self.model_type = config.MODEL_TYPE  # only 'i_blip' is handled by load_model()
        self.processor = None  # populated by load_model()
        self.model = None  # populated by load_model()
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE  # default longest-side limit for resize_image()
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self):
        """Backward-compatible alias for the old, misspelled attribute name."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value):
        self.skip_special_tokens = value

    def load_model(self):
        """Load the processor and model weights for ``self.model_type``.

        If both 8-bit and 4-bit quantization are requested, 8-bit wins and
        ``self.load_in_4bit`` is reset to ``False``.

        Raises:
            ValueError: if ``self.model_type`` is not a supported model type.
        """
        # The two quantization modes are mutually exclusive; prefer 8-bit.
        if self.load_in_4bit and self.load_in_8bit:
            self.load_in_4bit = False

        if self.model_type != 'i_blip':
            raise ValueError(
                f"Unsupported model type {self.model_type!r}; expected 'i_blip'."
            )

        # Quantization, dtype and device placement apply to model weights only,
        # so the processor (tokenizer + image preprocessing) is loaded without them.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        free_gpu_resources()
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )
        free_gpu_resources()

    def resize_image(self, image, max_image_size=None):
        """Downscale *image* so that its longer side is at most *max_image_size* pixels.

        The aspect ratio is preserved, and images already within the limit are
        returned unchanged (they are never upscaled).

        Args:
            image: a PIL image.
            max_image_size: longest-side limit in pixels. When ``None``, the
                configured ``self.max_image_size`` is used; if that is also
                ``None``, the ``MAX_IMAGE_SIZE`` environment variable, else 1024.

        Returns:
            The resized image, or *image* itself when no resize was needed.
        """
        if max_image_size is None:
            max_image_size = self.max_image_size
        if max_image_size is None:
            max_image_size = int(os.getenv("MAX_IMAGE_SIZE", "1024"))

        # PIL reports size as (width, height), and resize() takes the same order.
        width, height = image.size
        scale = max_image_size / max(width, height)
        if scale < 1:
            # max(1, ...) keeps very thin images from collapsing to a zero dimension.
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            image = image.resize(new_size, resample=Image.Resampling.LANCZOS)
        return image

    @staticmethod
    def _open_image(image_path):
        """Return a PIL image for a path, binary file object, or existing PIL image."""
        if isinstance(image_path, Image.Image):
            return image_path
        if isinstance(image_path, (str, os.PathLike, io.IOBase)):
            # copy() decodes the pixels into memory, so the file PIL opened for a
            # path can be closed here. A caller-supplied stream is left open.
            with Image.open(image_path) as opened:
                return opened.copy()
        raise TypeError(
            "image_path must be a path, a binary file object or a PIL image, "
            f"not {type(image_path).__name__}"
        )

    def _caption_image(self, image):
        """Run the processor and model on a PIL *image* and return the decoded caption."""
        # Tensors created here are released when this helper returns, so the
        # caller's cleanup can actually reclaim their GPU memory.
        inputs = self.processor(image, self.prompt, return_tensors="pt").to("cuda", self.torch_dtype)
        outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
        return self.processor.decode(outputs[0], skip_special_tokens=self.skip_special_tokens).strip()

    def generate_caption(self, image_path):
        """Generate a caption for a single image.

        Args:
            image_path: a filesystem path (``str`` or ``os.PathLike``), a binary
                file-like object, or an already-opened ``PIL.Image.Image``.

        Returns:
            The decoded caption with surrounding whitespace stripped.

        Raises:
            RuntimeError: if :meth:`load_model` has not been called.
            TypeError: if *image_path* is not one of the supported types.
        """
        if self.processor is None or self.model is None:
            raise RuntimeError("Model is not loaded; call load_model() first.")

        free_gpu_resources()
        try:
            image = self.resize_image(self._open_image(image_path))
            return self._caption_image(image)
        finally:
            # Release GPU memory even if preprocessing or generation fails.
            free_gpu_resources()

    def generate_captions_for_multiple_images(self, image_paths):
        """Return one caption per item of *image_paths*, in the same order."""
        return [self.generate_caption(image_path) for image_path in image_paths]
|
|
|
|
|
def get_caption(img):
    """Load a captioning model, caption *img*, and release the model again.

    The model is loaded fresh on every call and discarded afterwards, which
    trades speed for not keeping the weights resident in memory between calls.

    Args:
        img: anything ``ImageCaptioningModel.generate_caption`` accepts
            (a path, a binary file object, or a PIL image).

    Returns:
        The generated caption string.
    """
    captioner = ImageCaptioningModel()
    try:
        free_gpu_resources()
        captioner.load_model()
        free_gpu_resources()
        return captioner.generate_caption(img)
    finally:
        # Drop the last reference to the model before cleaning up; otherwise the
        # cleanup call cannot reclaim the memory held by the loaded weights.
        del captioner
        free_gpu_resources()