File size: 4,240 Bytes
75a53d9 7553f0c 75a53d9 b434799 ab38e0e 5f4a46b 75a53d9 178416a 75a53d9 398c0e8 75a53d9 1089b06 75a53d9 609d6f1 75a53d9 f711846 75a53d9 609d6f1 75a53d9 609d6f1 75a53d9 609d6f1 aefece3 15d3f2d 54d3921 75a53d9 609d6f1 75a53d9 8f97cdd 609d6f1 8f97cdd 609d6f1 8f97cdd 609d6f1 8f97cdd 5f4a46b 609d6f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import io
import torch
import PIL
from PIL import Image
from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration
import bitsandbytes
import accelerate
from my_model.config import captioning_config as config
from my_model.utilities.gen_utilities import free_gpu_resources
class ImageCaptioningModel:
    """Caption images with an InstructBLIP model configured by ``captioning_config``.

    Typical use: construct, call :meth:`load_model` once, then call
    :meth:`generate_caption` (or :meth:`generate_captions_for_multiple_images`)
    as many times as needed.
    """

    def __init__(self):
        """Copy loading and generation settings from the config module.

        No weights are loaded here; ``processor`` and ``model`` stay None until
        :meth:`load_model` is called.
        """
        self.model_type = config.MODEL_TYPE
        self.processor = None
        self.model = None
        self.prompt = config.PROMPT
        self.max_image_size = config.MAX_IMAGE_SIZE
        self.min_length = config.MIN_LENGTH
        self.max_new_tokens = config.MAX_NEW_TOKENS
        self.model_path = config.MODEL_PATH
        self.device_map = config.DEVICE_MAP
        self.torch_dtype = config.TORCH_DTYPE
        self.load_in_8bit = config.LOAD_IN_8BIT
        self.load_in_4bit = config.LOAD_IN_4BIT
        self.low_cpu_mem_usage = config.LOW_CPU_MEM_USAGE
        self.skip_special_tokens = config.SKIP_SPECIAL_TOKENS

    @property
    def skip_secial_tokens(self):
        """Misspelled alias of ``skip_special_tokens``, kept for backward compatibility."""
        return self.skip_special_tokens

    @skip_secial_tokens.setter
    def skip_secial_tokens(self, value):
        self.skip_special_tokens = value

    def load_model(self):
        """Load the processor and the model weights for ``self.model_type``.

        If both 8-bit and 4-bit loading are enabled, 8-bit wins and
        ``self.load_in_4bit`` is reset to False.

        Raises:
            ValueError: if ``self.model_type`` is not ``'i_blip'``.
        """
        if self.load_in_4bit and self.load_in_8bit:  # both set to True by mistake
            self.load_in_4bit = False
        if self.model_type != 'i_blip':
            # Failing here is clearer than leaving processor/model as None and
            # crashing later inside generate_caption.
            raise ValueError(f"Unsupported model type: {self.model_type!r} (expected 'i_blip')")
        # The processor only preprocesses images and tokenizes text; quantization,
        # dtype and device placement apply to the model weights, not to it.
        self.processor = InstructBlipProcessor.from_pretrained(self.model_path)
        free_gpu_resources()
        # NOTE(review): recent transformers releases deprecate load_in_8bit/load_in_4bit
        # in favour of quantization_config=BitsAndBytesConfig(...); kept as-is here,
        # confirm against the installed transformers version.
        self.model = InstructBlipForConditionalGeneration.from_pretrained(
            self.model_path,
            load_in_8bit=self.load_in_8bit,
            load_in_4bit=self.load_in_4bit,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
            device_map=self.device_map,
        )
        free_gpu_resources()

    @staticmethod
    def _scaled_size(width, height, max_image_size):
        """Return ``(new_width, new_height)`` fitting within ``max_image_size``, or None if no downscale is needed."""
        longest_side = max(width, height)
        if longest_side <= 0:
            return None
        scale = max_image_size / longest_side
        if scale >= 1:
            return None
        # Clamp to 1 so extreme aspect ratios never yield a zero-pixel side.
        return max(1, int(width * scale)), max(1, int(height * scale))

    def resize_image(self, image, max_image_size=None):
        """Downscale ``image`` so its longest side is at most ``max_image_size`` pixels.

        The aspect ratio is preserved and images that already fit are returned
        unchanged (never upscaled).

        Args:
            image: a PIL image.
            max_image_size: longest-side limit in pixels. When None, the
                ``MAX_IMAGE_SIZE`` environment variable is used if set, otherwise
                ``self.max_image_size`` (from config), otherwise 1024.

        Returns:
            The resized image, or ``image`` itself when no resize was needed.
        """
        if max_image_size is None:
            env_value = os.getenv("MAX_IMAGE_SIZE")
            if env_value is not None:
                max_image_size = int(env_value)
            elif getattr(self, "max_image_size", None) is not None:
                max_image_size = int(self.max_image_size)
            else:
                max_image_size = 1024
        # PIL's Image.size is (width, height); unpacking it the other way round
        # transposes the output size for non-square images.
        width, height = image.size
        new_size = self._scaled_size(width, height, max_image_size)
        if new_size is not None:
            image = image.resize(new_size, resample=PIL.Image.Resampling.LANCZOS)
        return image

    def generate_caption(self, image_path):
        """Generate a caption for one image using ``self.prompt``.

        Args:
            image_path: a filesystem path (``str`` or ``os.PathLike``), an open
                file-like object (``io.IOBase``), or an already-loaded PIL image.

        Returns:
            The decoded caption with surrounding whitespace stripped.

        Raises:
            TypeError: if ``image_path`` is none of the supported input types.
        """
        # NOTE(review): called twice back-to-back (kept from the original); confirm
        # whether the second call is actually needed.
        free_gpu_resources()
        free_gpu_resources()
        if isinstance(image_path, (str, os.PathLike, io.IOBase)):
            # Path or file-like object: let PIL open it.
            image = Image.open(image_path)
        elif isinstance(image_path, Image.Image):
            image = image_path
        else:
            raise TypeError(
                f"Unsupported image input type: {type(image_path).__name__}; "
                "expected a path, a file-like object, or a PIL.Image.Image"
            )
        image = self.resize_image(image)
        # Move inputs to wherever the model was placed (honours device_map) instead
        # of assuming the default CUDA device; floating tensors are cast to torch_dtype.
        inputs = self.processor(image, self.prompt, return_tensors="pt").to(self.model.device, self.torch_dtype)
        outputs = self.model.generate(**inputs, min_length=self.min_length, max_new_tokens=self.max_new_tokens)
        caption = self.processor.decode(outputs[0], skip_special_tokens=self.skip_special_tokens).strip()
        free_gpu_resources()
        free_gpu_resources()
        return caption

    def generate_captions_for_multiple_images(self, image_paths):
        """Caption each item of ``image_paths`` in order; see :meth:`generate_caption` for accepted types."""
        return [self.generate_caption(image_path) for image_path in image_paths]
def get_caption(img):
    """Caption a single image with a freshly loaded model, then release it.

    A new ``ImageCaptioningModel`` is created and loaded on every call.

    Args:
        img: anything ``ImageCaptioningModel.generate_caption`` accepts
            (a path, a file-like object, or a PIL image).

    Returns:
        The generated caption string.
    """
    captioner = ImageCaptioningModel()
    free_gpu_resources()
    captioner.load_model()
    free_gpu_resources()
    try:
        caption = captioner.generate_caption(img)
    finally:
        # Drop the only local reference to the model before the final cleanup;
        # while `captioner` is alive, its weights cannot be released.
        del captioner
        free_gpu_resources()
    return caption