maxcushion / app.py
import base64
import io

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, request, jsonify

app = Flask(__name__)
# Load the model once at startup
model_name = "colt12/maxcushion"
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    print("Loading model...")
    # Use half precision only on GPU; float16 is poorly supported on CPU.
    dtype = torch.float16 if device == "cuda" else torch.float32
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype)
    pipe = pipe.to(device)
    print("Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    raise
def generate_image(prompt):
    """Run the pipeline on the prompt and return the image as a base64-encoded PNG."""
    with torch.no_grad():
        image = pipe(prompt).images[0]

    # Encode the PIL image as a base64 PNG string so it can be returned in JSON
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
@app.route('/', methods=['GET'])
def home():
    return "Welcome to the Image Generation API. Use the /generate endpoint to generate images from prompts."


@app.route('/generate', methods=['POST'])
def generate():
    # get_json(silent=True) returns None instead of raising when the request
    # body is missing or is not valid JSON.
    data = request.get_json(silent=True)
    if not data or 'prompt' not in data:
        return jsonify({"error": "No prompt provided"}), 400

    try:
        result = generate_image(data['prompt'])
        return jsonify({"image": result})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    app.run(host='0.0.0.0', port=5000)
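
# Example client call (illustrative sketch; the host, port, and prompt text are
# assumptions for a local run, not part of the deployed service):
#
#   curl -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "a maximalist cushion on a velvet sofa"}'
#
# The response is JSON of the form {"image": "<base64-encoded PNG>"}; decode the
# value with base64.b64decode() and write the resulting bytes to a .png file.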