# Spaces: Running on Zero
from flask import Flask, request, jsonify, send_file | |
from PIL import Image | |
import base64 | |
import io | |
import random | |
import uuid | |
import numpy as np | |
import spaces | |
import torch | |
from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler | |
# The Flask application object; its import name lets Flask locate resources
# relative to this module.
app = Flask(import_name=__name__)
def clear_gpu_memory():
    """Release unoccupied cached CUDA memory and wait for pending GPU work.

    Does nothing when CUDA is unavailable. ``torch.cuda.synchronize`` raises
    on a host without CUDA, and the model loader in this file explicitly
    allows running without a GPU. Without this guard, every request would
    fail on such a host.
    """
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
# Global holding the diffusion pipeline; load_model() fills it exactly once.
pipe = None


def load_model():
    """Create and configure the global SDXL pipeline on first call.

    On the first call this:
    - loads ``fluently/Fluently-XL-v2`` with float16 weights from safetensors;
    - replaces its scheduler with an ``EulerAncestralDiscreteScheduler``
      built from the existing scheduler's config;
    - loads the ``ehristoforu/dalle-3-xl-v2`` LoRA under the adapter name
      ``"dalle"`` and activates it;
    - moves the pipeline to ``"cuda"`` when CUDA is available.

    Later calls return immediately because ``pipe`` is already set.
    """
    global pipe
    if pipe is not None:
        return  # already initialised

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "fluently/Fluently-XL-v2",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    scheduler_config = pipe.scheduler.config
    pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    pipe.load_lora_weights(
        "ehristoforu/dalle-3-xl-v2",
        weight_name="dalle-3-xl-lora-v2.safetensors",
        adapter_name="dalle",
    )
    pipe.set_adapters("dalle")
    if torch.cuda.is_available():
        pipe.to("cuda")


# Build the pipeline at import time, so it exists before any request handler runs.
load_model()
def save_image(img):
    """Save *img* under a fresh random file name and return that name.

    The name has the form ``<uuid4>.png`` and is relative to the current
    working directory. *img* only needs a ``save(path)`` method; it is a
    PIL image here.
    """
    filename = f"{uuid.uuid4()}.png"
    img.save(filename)
    return filename
# Upper bound (inclusive) for randomly drawn seeds: the largest signed 32-bit int.
MAX_SEED = np.iinfo(np.int32).max


def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return *seed* unchanged, or a random seed in ``[0, MAX_SEED]`` if *randomize_seed* is truthy."""
    return random.randint(0, MAX_SEED) if randomize_seed else seed
def generate(
    prompt: str,
    negative_prompt: str = "",
    use_negative_prompt: bool = False,
    seed: int = 0,
    num_images_per_prompt: int = 1,
    width: int = 512,  # Reduced image width
    height: int = 512,  # Reduced image height
    guidance_scale: float = 3,
    randomize_seed: bool = False,
):
    """Run the global ``pipe`` on *prompt* and save every resulting image.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline. It is
            replaced with "" unless *use_negative_prompt* is true.
        use_negative_prompt: Whether to forward *negative_prompt*.
        seed: Seed for the random generator. It is replaced by a random value
            when *randomize_seed* is true.
        num_images_per_prompt: Number of images to generate.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale passed to the pipeline.
        randomize_seed: Draw a fresh seed (via ``randomize_seed_fn``) instead
            of using *seed*.

    Returns:
        A tuple ``(image_paths, seed)``. *image_paths* are the file names
        written by ``save_image``; *seed* is the seed that was actually fed to
        the pipeline's generator.
    """
    # NOTE(review): `spaces` is imported at file top, but no `@spaces.GPU`
    # decorator is visible on this function. If this runs on ZeroGPU hardware,
    # confirm how GPU access is granted here.
    seed = int(randomize_seed_fn(seed, randomize_seed))
    if not use_negative_prompt:
        negative_prompt = ""
    # Previously the seed was only reported, never applied, so it could not
    # reproduce the output. Seed an explicit generator instead. A CPU generator
    # is what diffusers recommends for results that are reproducible across
    # devices.
    generator = torch.Generator().manual_seed(seed)
    images = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        num_inference_steps=25,
        num_images_per_prompt=num_images_per_prompt,
        cross_attention_kwargs={"scale": 0.65},  # LoRA strength for the loaded adapter
        output_type="pil",
        generator=generator,
    ).images
    image_paths = [save_image(img) for img in images]
    print(image_paths)
    return image_paths, seed
def root():
    """Return the plain-text greeting for the service root.

    NOTE(review): no route decorator is visible for this view; confirm how it
    is registered with ``app``.
    """
    greeting = "Welcome to the Fashion Outfit"
    return greeting
def _is_generated_image_name(name):
    """Return True only if *name* is exactly ``<canonical uuid>.png``, the form ``save_image`` writes."""
    if not isinstance(name, str) or not name.endswith(".png"):
        return False
    stem = name[: -len(".png")]
    try:
        # Round-tripping through uuid.UUID rejects anything that is not the
        # canonical lower-case, hyphenated form, including path separators.
        return str(uuid.UUID(stem)) == stem
    except ValueError:
        return False


def get_image(image_id):
    """Serve a PNG previously written by ``save_image``.

    *image_id* comes from the client, presumably a URL path segment (confirm
    against the route registration, which is not visible here). The original
    code passed it straight to ``send_file``, so a value like ``../app.py`` or
    an absolute path could read any file the process can access. Now only
    names of the ``<uuid>.png`` form that ``save_image`` generates are served.
    Anything else gets the same 404 JSON response as a missing file.

    Returns:
        The PNG file response, or ``({'error': 'Image not found'}, 404)`` as JSON.
    """
    if not _is_generated_image_name(image_id):
        return jsonify({'error': 'Image not found'}), 404
    try:
        return send_file(image_id, mimetype='image/png')
    except FileNotFoundError:
        return jsonify({'error': 'Image not found'}), 404
def run():
    """Handle a JSON generation request and return the generated image names.

    JSON body fields. Only ``prompt`` is required. Missing or null fields fall
    back to the same defaults as ``generate``:
        prompt (str), negative_prompt (str, ""), use_negative_prompt (bool,
        False), guidance_scale (number, 3), randomize_seed (bool, False),
        num_images_per_prompt (int, 1), width (int, 512), height (int, 512),
        seed (int, 0).

    Returns:
        On success, ``{'out': result}``, where ``result`` is ``generate``'s
        ``(image_paths, seed)`` tuple; ``jsonify`` serializes it as a list.
        Returns ``({'error': ...}, 400)`` when the body is not a JSON object,
        ``prompt`` is missing, or a numeric field cannot be converted.

    NOTE(review): no route decorator is visible for this view; confirm how it
    is registered with ``app``.
    """
    # silent=True yields None instead of raising on a missing or invalid JSON
    # body, so the client gets a JSON 400 rather than an HTML error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object'}), 400
    print(data)
    if 'prompt' not in data:
        return jsonify({'error': "Missing required field 'prompt'"}), 400

    def field(key, default, cast):
        # Missing/null -> default; present values are coerced, so "3" and 3
        # both become a usable number for the pipeline.
        value = data.get(key)
        return default if value is None else cast(value)

    try:
        guidance_scale = field('guidance_scale', 3, float)
        num_images_per_prompt = field('num_images_per_prompt', 1, int)
        width = field('width', 512, int)
        height = field('height', 512, int)
        seed = field('seed', 0, int)
    except (TypeError, ValueError):
        return jsonify({
            'error': 'guidance_scale, num_images_per_prompt, width, height '
                     'and seed must be numbers',
        }), 400

    clear_gpu_memory()
    result = generate(
        data['prompt'],
        negative_prompt=data.get('negative_prompt', ''),
        use_negative_prompt=data.get('use_negative_prompt', False),
        seed=seed,
        num_images_per_prompt=num_images_per_prompt,
        width=width,
        height=height,
        guidance_scale=guidance_scale,
        randomize_seed=data.get('randomize_seed', False),
    )
    return jsonify({'out': result})
if __name__ == "__main__":
    # Run Flask's built-in server on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    app.run(**server_options)