|
import io |
|
from PIL import Image |
|
import torch |
|
from flask import Flask, request, jsonify |
|
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer |
|
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)

# Checkpoint identifier handed to VisionEncoderDecoderModel.from_pretrained.
model_name = "colt12/maxcushion"


def _load_pretrained(label, loader, source):
    """Return ``loader.from_pretrained(source)``, printing progress around it.

    ``label`` is the lower-case component name used in the progress messages.
    """
    print(f"Loading {label}...")
    component = loader.from_pretrained(source)
    print(f"{label.capitalize()} loaded successfully.")
    return component


# All three inference components are loaded once, at import time. A failure
# in any of them is reported on stdout and then re-raised, so the module
# never finishes importing with a missing component.
try:
    model = _load_pretrained("model", VisionEncoderDecoderModel, model_name)
    feature_extractor = _load_pretrained(
        "feature extractor",
        ViTFeatureExtractor,
        "google/vit-base-patch16-224-in21k",
    )
    tokenizer = _load_pretrained("tokenizer", AutoTokenizer, "gpt2")
except Exception as exc:
    print(f"Error loading model or processors: {exc!s}")
    raise
|
def predict(image_bytes):
    """Generate a caption for an encoded image.

    Args:
        image_bytes: Raw bytes of an image file in a format Pillow can open.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # Decode the upload and force three-channel RGB before feature extraction.
    buffer = io.BytesIO(image_bytes)
    rgb_image = Image.open(buffer).convert("RGB")

    encoded = feature_extractor(images=rgb_image, return_tensors="pt")

    # Inference only: gradient tracking is switched off for generation.
    with torch.no_grad():
        token_ids = model.generate(
            encoded.pixel_values,
            max_length=50,
            num_return_sequences=1,
        )

    return tokenizer.decode(token_ids[0], skip_special_tokens=True)
|
@app.route('/', methods=['GET'])
def home():
    """Return the plain-text welcome message for ``GET /``."""
    welcome = (
        "Welcome to the Image Captioning API. "
        "Use the /predict endpoint to generate captions for images."
    )
    return welcome
|
@app.route('/predict', methods=['POST'])
def run():
    """Handle ``POST /predict``: caption the file uploaded as ``image``.

    Returns:
        ``{"caption": <text>}`` (status 200) on success.
        ``{"error": <message>}`` with status 400 when the ``image`` field is
        missing, the uploaded content is empty, or Pillow cannot read it as
        an image.
        ``{"error": <message>}`` with status 500 for any other failure; the
        traceback is written to the application log.
    """
    image_file = request.files.get('image')
    if image_file is None:
        return jsonify({"error": "No image provided"}), 400

    try:
        image_bytes = image_file.read()
        if not image_bytes:
            # The field was present but carried no content (for example a
            # form submitted without selecting a file). That is a client
            # error, not a server failure.
            return jsonify({"error": "No image provided"}), 400

        try:
            result = predict(image_bytes)
        except OSError:
            # Pillow reports unidentifiable or truncated image data as
            # OSError (UnidentifiedImageError is a subclass), so this is a
            # bad upload rather than an internal error.
            app.logger.warning("Rejected upload that is not a readable image")
            return jsonify({"error": "Uploaded file is not a valid image"}), 400

        return jsonify({"caption": result})
    except Exception as e:
        # Keep the original 500 response shape, but record the traceback so
        # the failure can be diagnosed server-side.
        app.logger.exception("Caption generation failed")
        return jsonify({"error": str(e)}), 500
|
if __name__ == "__main__":
    # Direct execution only: start Flask's built-in server on port 5000,
    # bound to 0.0.0.0 so it accepts connections on every interface.
    app.run(port=5000, host="0.0.0.0")