|
python |
|
import base64
import io
import threading

import torch
from diffusers import StableDiffusionPipeline
from flask import Flask, jsonify, request
from PIL import Image
|
|
|
# The Flask application object; the routes in this module are registered on it.
app = Flask(import_name=__name__)
|
|
|
|
|
model_name = "colt12/maxcushion"

# Pick the device first, then a dtype that device can actually run.
# float16 halves GPU memory use. PyTorch's CPU backend lacks half-precision
# kernels for many ops ("not implemented for 'Half'"), so loading in
# float16 unconditionally makes CPU-only inference fail. Use float32 there.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

try:
    print("Loading model...")
    pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=dtype)
    pipe = pipe.to(device)
    print("Model loaded successfully.")
except Exception as e:
    # Report the failure, then re-raise: the service cannot run without a model.
    print(f"Error loading model: {e}")
    raise
|
|
|
# Serializes access to the single shared `pipe`. diffusers pipelines keep
# per-call state in their scheduler (e.g. timesteps set at the start of each
# call). Flask's built-in server handles requests on multiple threads by
# default, so unsynchronized concurrent calls could interfere.
_PIPE_LOCK = threading.Lock()


def generate_image(prompt):
    """Generate one image for *prompt* and return it as base64-encoded PNG text.

    Args:
        prompt: Text prompt passed to the module-level ``pipe``.

    Returns:
        The first image produced by ``pipe``, PNG-encoded and then base64-encoded
        into a ``str``.

    Raises:
        Any exception raised by ``pipe`` or by PNG encoding propagates unchanged.
    """
    # Hold the lock only for inference; encoding below touches no shared state.
    with _PIPE_LOCK, torch.no_grad():
        image = pipe(prompt).images[0]

    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
|
|
|
@app.route("/", methods=["GET"])
def home():
    """Serve a plain-text greeting that points clients at the /generate endpoint."""
    greeting = (
        "Welcome to the Image Generation API. "
        "Use the /generate endpoint to generate images from prompts."
    )
    return greeting
|
|
|
@app.route('/generate', methods=['POST'])
def run():
    """Generate an image from the ``prompt`` field of a JSON request body.

    Expects a JSON object such as ``{"prompt": "a red armchair"}``.

    Returns:
        On success, a JSON response ``{"image": <base64 PNG>}`` (status 200).
        Status 400 with ``{"error": ...}`` when the body is not a JSON object,
        or the prompt is missing/null, not a string, or blank.
        Status 500 with ``{"error": ...}`` when image generation raises.
    """
    # silent=True returns None (instead of raising) for a missing, malformed,
    # or non-JSON body, so every bad request gets this API's JSON 400 shape.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # Also guards against JSON strings/lists, where `in` would do a
        # substring/element test and indexing would raise a TypeError.
        return jsonify({"error": "Request body must be a JSON object"}), 400

    prompt = data.get("prompt")
    if prompt is None:
        return jsonify({"error": "No prompt provided"}), 400
    if not isinstance(prompt, str) or not prompt.strip():
        return jsonify({"error": "Prompt must be a non-empty string"}), 400

    try:
        result = generate_image(prompt)
    except Exception as e:
        # Top-level request boundary: record the full traceback server-side.
        # NOTE(review): the raw exception text is still returned to the client
        # for compatibility; consider a generic message if it may leak internals.
        app.logger.exception("Image generation failed")
        return jsonify({"error": str(e)}), 500
    return jsonify({"image": result})
|
|
|
def main():
    """Start Flask's built-in server, listening on all interfaces at port 5000."""
    app.run(port=5000, host='0.0.0.0')


if __name__ == "__main__":
    main()