maxcushion / app.py
colt12's picture
Update app.py
a31ed9b verified
raw
history blame
1.33 kB
import base64
import threading
from io import BytesIO
from typing import Dict, List

import torch
from diffusers import StableDiffusionXLPipeline
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
app = FastAPI()

# Load the model once at import time; every request reuses this pipeline.
model_name = "colt12/maxcushion"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is only used on CUDA. PyTorch's CPU backend is slow in half precision
# and has historically lacked float16 kernels for many ops, so the CPU fallback
# runs in float32. The "fp16" weight variant is still what gets downloaded;
# torch_dtype upcasts it on load when running on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_name,
    torch_dtype=dtype,
    use_safetensors=True,
    variant="fp16",
)
pipe = pipe.to(device)
class Item(BaseModel):
    """Request body for ``POST /generate``.

    Each field is forwarded unchanged to the SDXL pipeline call as the
    keyword argument of the same name.
    """

    # Text prompt describing the image to generate (required).
    prompt: str
    # Optional negative prompt; defaults to the empty string.
    negative_prompt: str = ""
    # Number of denoising steps. Must be >= 1: zero or negative values cannot
    # run, so they are rejected at validation (422) instead of failing inside
    # the pipeline. NOTE(review): there is no upper bound, so a client can
    # request arbitrarily long generations; consider capping this.
    num_inference_steps: int = Field(30, ge=1)
    # Guidance scale passed through to the pipeline.
    guidance_scale: float = 7.5
# The module-level `pipe` is shared by every request. Its scheduler keeps
# per-call state, so concurrent calls from worker threads must not overlap.
# The original handler was serialized implicitly because it blocked the event
# loop; now that the work runs in a threadpool, serialize it explicitly.
_pipe_lock = threading.Lock()


def _render_png_base64(item: Item) -> str:
    """Generate one image for *item* and return it as base64-encoded PNG.

    Blocking and potentially slow; call it from a worker thread, never
    directly on the event loop.
    """
    with _pipe_lock:
        image = pipe(
            prompt=item.prompt,
            negative_prompt=item.negative_prompt,
            num_inference_steps=item.num_inference_steps,
            guidance_scale=item.guidance_scale,
        ).images[0]
    # PNG encoding does not touch the pipeline, so it runs outside the lock.
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


@app.post("/generate")
async def generate(item: Item) -> Dict[str, str]:
    """Generate an image from the request and return it base64-encoded.

    Returns:
        ``{"image": <base64 PNG string>}``.

    Raises:
        HTTPException: status 500 with the underlying error message if
            generation or encoding fails.
    """
    try:
        # Offload to the threadpool so the event loop keeps serving other
        # requests (e.g. GET /) while the pipeline runs.
        image_base64 = await run_in_threadpool(_render_png_base64, item)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"image": image_base64}
@app.get("/")
async def root():
    """Landing endpoint: return a short description of this API."""
    payload = {"message": "SDXL Image Generation API"}
    return payload