"""Flask service that generates captions for uploaded images.

The captioning model, the image feature extractor and the tokenizer are
loaded once at import time; ``POST /predict`` then runs one image through
them and returns the generated caption as JSON.
"""

import io

from PIL import Image
import torch
from flask import Flask, request, jsonify
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer

app = Flask(__name__)

# Load the model and processors
model_name = "colt12/maxcushion"

try:
    print("Loading model...")
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    print("Model loaded successfully.")

    print("Loading feature extractor...")
    # NOTE(review): newer transformers releases appear to deprecate
    # ViTFeatureExtractor in favour of ViTImageProcessor -- confirm against
    # the pinned transformers version before switching.
    feature_extractor = ViTFeatureExtractor.from_pretrained(
        "google/vit-base-patch16-224-in21k"
    )
    print("Feature extractor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    # Report the failure, then re-raise: the service cannot run without
    # all three components.
    print(f"Error loading model or processors: {e}")
    raise


def _decode_image(image_bytes: bytes) -> Image.Image:
    """Decode raw image bytes into an RGB PIL image."""
    # convert() forces the pixel data to be loaded, so decoding problems
    # surface here rather than later during preprocessing.
    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _caption_image(image: Image.Image) -> str:
    """Run the captioning model on an already-decoded image."""
    # Preprocess the image
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    # Generate the caption; no gradients are needed for inference.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values, max_length=50, num_return_sequences=1
        )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


def predict(image_bytes: bytes) -> str:
    """Generate a caption for an image supplied as raw bytes.

    Args:
        image_bytes: The encoded image file contents.

    Returns:
        The caption decoded from the model output, with special tokens
        skipped.

    Raises:
        Any exception raised while decoding the image or running the
        model is propagated unchanged.
    """
    return _caption_image(_decode_image(image_bytes))


@app.route('/', methods=['GET'])
def home():
    """Return a short plain-text description of the API."""
    return "Welcome to the Image Captioning API. Use the /predict endpoint to generate captions for images."


@app.route('/predict', methods=['POST'])
def run():
    """Caption the image uploaded in the ``image`` multipart form field.

    Returns:
        ``{"caption": ...}`` on success; ``{"error": ...}`` with status
        400 when the upload is missing, empty or not a decodable image,
        and with status 500 for any other failure.
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400

    image_file = request.files['image']

    try:
        image_bytes = image_file.read()

        # A form field submitted without a file yields an empty payload;
        # that is a client error, not a server fault.
        if not image_bytes:
            return jsonify({"error": "No image provided"}), 400

        try:
            image = _decode_image(image_bytes)
        except OSError:
            # Image decoding failures are caught here as OSError; bad
            # input from the client maps to 400 rather than 500.
            return jsonify({"error": "Invalid or unsupported image file"}), 400

        result = _caption_image(image)
        return jsonify({"caption": result})
    except Exception as e:
        # Record the traceback server-side so failures are diagnosable.
        app.logger.exception("Caption generation failed")
        return jsonify({"error": str(e)}), 500


if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)