"""Image-captioning inference handler for the ``colt12/maxcushion`` model.

The model, image processor and tokenizer are loaded once at import time;
``inference(inputs)`` turns raw image bytes (or a file-like object) into a
caption string.
"""

import io

from PIL import Image
import torch
from transformers import (
    AutoConfig,
    AutoTokenizer,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
    ViTImageProcessor,
)
from transformers.utils import cached_file

# Load the model and processors
model_name = "colt12/maxcushion"

# Base checkpoints: used to assemble the model when the repo's config is not a
# VisionEncoderDecoder config, and as the source of the image processor and
# tokenizer.
ENCODER_CHECKPOINT = "google/vit-base-patch16-224-in21k"
DECODER_CHECKPOINT = "gpt2"
# NOTE(review): assumes the repo ships a plain PyTorch state dict under this
# name when it is not a full VisionEncoderDecoder checkpoint — confirm.
WEIGHTS_FILENAME = "pytorch_model.bin"


def _load_model(repo_id):
    """Load the captioning model for *repo_id*.

    If the repo's config is a ``VisionEncoderDecoderConfig`` the checkpoint is
    loaded directly. Otherwise a ViT/GPT-2 encoder-decoder is assembled from the
    base checkpoints and the repo's state dict is loaded into it.

    Raises:
        Whatever ``transformers``/``torch`` raise on missing files, network
        failures, or a state dict that does not match the assembled model.
    """
    print("Loading model configuration...")
    config = AutoConfig.from_pretrained(repo_id)

    print("Loading model...")
    # AutoConfig returns a *config* object, so compare against the config
    # class (comparing against the model class is always False).
    if isinstance(config, VisionEncoderDecoderConfig):
        return VisionEncoderDecoderModel.from_pretrained(repo_id, config=config)

    # Not a full encoder-decoder checkpoint: build the architecture manually.
    model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ENCODER_CHECKPOINT, DECODER_CHECKPOINT
    )
    # repo_id is a Hub id, not a local directory; resolve (and cache) the
    # weights file through the Hub instead of building a filesystem path.
    weights_path = cached_file(repo_id, WEIGHTS_FILENAME)
    # Downloaded file: load tensors only (no arbitrary unpickling) and onto CPU
    # so the load works on machines without the device it was saved from.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    return model


def _ensure_generation_token_ids(model, tokenizer):
    """Fill in token ids that ``generate()`` needs but a freshly assembled
    ViT/GPT-2 model lacks.

    GPT-2 defines no pad token, and an encoder-decoder built from two separate
    checkpoints has no decoder start token. Existing values are left untouched.
    Both the model config and, when present, its generation config are patched,
    because newer ``transformers`` versions read the latter during generation.
    """
    targets = [model.config]
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        targets.append(generation_config)
    for cfg in targets:
        if getattr(cfg, "decoder_start_token_id", None) is None:
            cfg.decoder_start_token_id = tokenizer.bos_token_id
        if getattr(cfg, "pad_token_id", None) is None:
            cfg.pad_token_id = tokenizer.eos_token_id


try:
    model = _load_model(model_name)
    model.eval()  # inference only: make sure dropout etc. are disabled
    print("Model loaded successfully.")

    print("Loading image processor...")
    image_processor = ViTImageProcessor.from_pretrained(ENCODER_CHECKPOINT)
    print("Image processor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(DECODER_CHECKPOINT)
    print("Tokenizer loaded successfully.")

    _ensure_generation_token_ids(model, tokenizer)
except Exception as e:
    # Top-level boundary: report, then re-raise so the host fails loudly.
    print(f"Error loading model or processors: {e}")
    raise


def predict(image_bytes, max_length=50):
    """Generate a caption for an encoded image.

    Args:
        image_bytes: Encoded image data in any format PIL can open.
        max_length: Maximum length of the generated token sequence.

    Returns:
        The decoded caption with special tokens removed.
    """
    # Decode the image and normalise to 3-channel RGB; closing releases PIL's
    # handle on the source buffer.
    with Image.open(io.BytesIO(image_bytes)) as source:
        image = source.convert("RGB")

    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(
            pixel_values, max_length=max_length, num_return_sequences=1
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def inference(inputs):
    """Caption the image supplied in *inputs*.

    Args:
        inputs: Mapping containing either ``"file"`` (an object with a
            ``read()`` method returning bytes) or ``"bytes"`` (raw image
            bytes). ``"file"`` takes precedence when both are present.

    Returns:
        ``{"caption": <generated caption>}``.

    Raises:
        ValueError: If neither ``"file"`` nor ``"bytes"`` is present.
    """
    if "file" in inputs:
        image_bytes = inputs["file"].read()
    elif "bytes" in inputs:
        image_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    return {"caption": predict(image_bytes)}