import io

from PIL import Image
import torch
from transformers import VisionEncoderDecoderModel, ViTImageProcessor, AutoTokenizer

# Load the model and processors
model_name = "colt12/maxcushion"
try:
    print("Loading model...")
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    print("Model loaded successfully.")

    print("Loading image processor...")
    image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
    print("Image processor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or processors: {str(e)}")
    raise


def predict(image_bytes):
    # Open the image using PIL and ensure a 3-channel RGB input
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Preprocess the image into the tensor format the ViT encoder expects
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # Generate the caption without tracking gradients
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=50, num_return_sequences=1)
        generated_caption = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_caption


def inference(inputs):
    # Accept either a file-like object ('file') or raw bytes ('bytes')
    if "file" in inputs:
        image = inputs["file"]
        image_bytes = image.read()
    elif "bytes" in inputs:
        image_bytes = inputs["bytes"]
    else:
        raise ValueError("No valid input found. Expected 'file' or 'bytes'.")

    # Generate the caption
    result = predict(image_bytes)

    # Return the result
    return {"caption": result}
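

# Example usage (a minimal sketch, not part of the original script): reads a
# local image from disk and prints the generated caption. The path
# "example.jpg" is a hypothetical placeholder; substitute any RGB image file.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        result = inference({"bytes": f.read()})
    print(result["caption"])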