maxcushion / app.py
colt12's picture
Update app.py
4043227 verified
raw
history blame
1.94 kB
import io
from PIL import Image
import torch
from flask import Flask, request, jsonify
from transformers import VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
app = Flask(__name__)

# Hub identifier of the fine-tuned image-captioning checkpoint.
model_name = "colt12/maxcushion"

# Load the model and its two processors at import time so the first
# request doesn't pay the download/initialization cost. Any failure is
# logged and re-raised so the process dies loudly rather than serving a
# half-initialized app.
try:
    print("Loading model...")
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    print("Model loaded successfully.")

    print("Loading feature extractor...")
    # NOTE(review): the extractor comes from the base ViT checkpoint, not
    # from model_name — presumably matching how the model was trained;
    # confirm against the training setup.
    feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
    print("Feature extractor loaded successfully.")

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or processors: {str(e)}")
    raise
def predict(image_bytes):
    """Generate a caption string for the image encoded in *image_bytes*.

    The bytes are decoded with PIL, normalized to RGB, converted to the
    model's expected pixel tensor, and captioned greedily (max 50 tokens).
    """
    # Decode the raw upload into an RGB PIL image.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Preprocess into the (1, C, H, W) pixel tensor the encoder expects.
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        output_ids = model.generate(pixel_values, max_length=50, num_return_sequences=1)

    # Single sequence requested, so decode the first (only) candidate.
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
@app.route('/', methods=['GET'])
def home():
    """Landing page: point callers at the /predict endpoint."""
    return (
        "Welcome to the Image Captioning API. "
        "Use the /predict endpoint to generate captions for images."
    )
@app.route('/predict', methods=['POST'])
def run():
    """Caption an uploaded image.

    Expects a multipart/form-data POST with the file under the 'image'
    field. Returns JSON {"caption": ...} on success, or {"error": ...}
    with status 400 (no usable image supplied) or 500 (decoding or
    inference failed).
    """
    if 'image' not in request.files:
        return jsonify({"error": "No image provided"}), 400

    image_file = request.files['image']
    # A submitted form whose file field was left empty arrives with an
    # empty filename; reject it up front as a 400 instead of letting
    # PIL fail on empty bytes and surfacing a confusing 500.
    if image_file.filename == '':
        return jsonify({"error": "No image provided"}), 400

    try:
        image_bytes = image_file.read()
        result = predict(image_bytes)
        return jsonify({"caption": result})
    except Exception as e:
        # predict() can raise on undecodable image data or model errors;
        # surface the reason to the client rather than swallowing it.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from outside the
    # container/host on port 5000.
    app.run(host='0.0.0.0', port=5000)