# app.py — Hugging Face Space by Saad0KH
# Last commit: e2e8031 "Update app.py" (verified); file size 3.43 kB
from flask import Flask, request, jsonify, send_file
from PIL import Image
import base64
import io
import random
import uuid
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
# The Flask application object; every route in this module registers on it.
app = Flask(import_name=__name__)
def clear_gpu_memory():
    """Ask PyTorch to release cached CUDA memory and wait for queued GPU work.

    Does nothing when CUDA is unavailable. The previous unconditional
    ``torch.cuda.synchronize()`` forces CUDA initialisation and therefore
    raised on CPU-only hosts, turning every request into an error there.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# Process-wide diffusion pipeline; stays None until load_model() succeeds.
pipe = None


def load_model():
    """Build the Fluently-XL SDXL pipeline with the "dalle" LoRA and cache it in ``pipe``.

    Once ``pipe`` is set, later calls return immediately. The pipeline is
    fully configured in a local variable and published to the module-level
    ``pipe`` only as the last step. If any step fails part-way (e.g. the LoRA
    download), ``pipe`` therefore stays ``None`` and a later call retries.
    Before this change, a half-configured pipeline could be cached and
    silently reused.
    """
    global pipe
    if pipe is not None:
        return
    # NOTE(review): float16 is used unconditionally; on a CPU-only host many
    # fp16 ops may be slow or unsupported — confirm deployment always has CUDA.
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "fluently/Fluently-XL-v2",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
    pipeline.load_lora_weights(
        "ehristoforu/dalle-3-xl-v2",
        weight_name="dalle-3-xl-lora-v2.safetensors",
        adapter_name="dalle",
    )
    pipeline.set_adapters("dalle")
    if torch.cuda.is_available():
        pipeline.to("cuda")
    pipe = pipeline


# Load the model once, at import time (i.e. during app initialization).
load_model()
def save_image(img):
    """Write *img* to the current working directory under a random name.

    The name is ``"<uuid4>.png"`` and is returned as-is (no directory part);
    presumably this is the id clients pass back to ``/api/get_image/<image_id>``.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
# Upper bound for generated seeds: the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return a fresh random seed in ``[0, MAX_SEED]`` when asked, else *seed* unchanged."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
@spaces.GPU(enable_queue=True)
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    num_images_per_prompt: int = 1,
    width: int = 512,  # Reduced image width
    height: int = 512,  # Reduced image height
    guidance_scale: float = 3,
    randomize_seed: bool = False,
    num_inference_steps: int = 25,
    lora_scale: float = 0.65,
):
    """Run the module-level ``pipe`` on *prompt* and save each result as a PNG.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt; only used when *use_negative_prompt*.
        use_negative_prompt: When false, an empty negative prompt is sent.
        seed: Seed for the pipeline's random generator; replaced by a random
            value when *randomize_seed* is true.
        num_images_per_prompt: Number of images to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        randomize_seed: Draw a fresh seed instead of using *seed*.
        num_inference_steps: Scheduler steps (default 25, the previous
            hard-coded value).
        lora_scale: Scale sent via ``cross_attention_kwargs`` (default 0.65,
            the previous hard-coded value).

    Returns:
        ``(image_paths, seed)``: the saved file names and the seed actually
        used. Passing that seed back with ``randomize_seed=False``
        reproduces the result.
    """
    seed = int(randomize_seed_fn(seed, randomize_seed))
    if not use_negative_prompt:
        negative_prompt = ""  # type: ignore
    # Previously the seed was computed and returned but never handed to the
    # pipeline, so it had no effect. diffusers recommends a CPU generator for
    # results that are reproducible across devices.
    generator = torch.Generator().manual_seed(seed)
    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        num_images_per_prompt=num_images_per_prompt,
        cross_attention_kwargs={"scale": lora_scale},
        generator=generator,
        output_type="pil",
    ).images
    image_paths = [save_image(img) for img in images]
    print(image_paths)
    return image_paths, seed
@app.get("/")
def root():
    """Landing endpoint: reply with a fixed plain-text greeting."""
    greeting = "Welcome to the Fashion Outfit"
    return greeting
@app.route('/api/get_image/<image_id>', methods=['GET'])
def get_image(image_id):
    """Serve a PNG previously written by ``save_image``.

    *image_id* must be a generated file name of the form ``<uuid4>.png``
    (canonical lowercase, hyphenated UUID). Anything else gets a 404.
    Previously the raw URL segment went straight to ``send_file``, which let
    a client download any file reachable by a bare name (e.g. ``app.py``).

    Returns:
        The file with mimetype ``image/png``, or ``{"error": ...}`` with 404.
    """
    stem, _, ext = image_id.rpartition(".")
    try:
        # str(UUID(...)) round-trips only canonical names, which rejects
        # braces, "urn:" prefixes and other accepted-but-unexpected spellings.
        is_generated_name = ext == "png" and str(uuid.UUID(stem)) == stem
    except ValueError:
        is_generated_name = False
    if not is_generated_name:
        return jsonify({'error': 'Image not found'}), 404
    try:
        return send_file(image_id, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
@app.route('/api/run', methods=['POST'])
def run():
    """Generate images from the JSON request body.

    Required fields: ``prompt``, ``negative_prompt``, ``use_negative_prompt``,
    ``guidance_scale``, ``randomize_seed``, ``num_images_per_prompt``.
    Optional fields: ``width`` and ``height`` (default 512) and ``seed``
    (default 0, the value previously hard-coded).

    Returns:
        ``{"out": [image_paths, seed]}`` on success. Returns
        ``{"error": ...}`` with status 400 when the body is not a JSON object
        or a required field is missing; before, these cases failed with an
        unhandled error.
    """
    data = request.get_json(silent=True)
    print(data)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    required = (
        'prompt',
        'negative_prompt',
        'use_negative_prompt',
        'guidance_scale',
        'randomize_seed',
        'num_images_per_prompt',
    )
    missing = [field for field in required if field not in data]
    if missing:
        return jsonify({'error': f"Missing required field(s): {', '.join(missing)}"}), 400
    clear_gpu_memory()
    result = generate(
        data['prompt'],
        data['negative_prompt'],
        data['use_negative_prompt'],
        data.get('seed', 0),
        data['num_images_per_prompt'],
        data.get('width', 512),  # Default width
        data.get('height', 512),  # Default height
        data['guidance_scale'],
        data['randomize_seed'],
    )
    # jsonify serialises the (image_paths, seed) tuple as a two-element list.
    return jsonify({'out': result})
if __name__ == "__main__":
    # Start Flask's built-in server on all interfaces, port 7860.
    app.run(port=7860, host="0.0.0.0")