import os
import random
import hashlib
import pickle
import traceback
import logging

import gradio as gr
import spaces
import torch
import yaml
from diffusers import DiffusionPipeline
from huggingface_hub import login
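# ---------------------------------------------------------------------------
# A plausible requirements.txt for this Space (a sketch only; the exact
# package list and pins are assumptions, not taken from the repo):
#
#   torch
#   diffusers
#   transformers
#   accelerate
#   sentencepiece
#   gradio
#   huggingface_hub
#   spaces
#   pyyaml
# ---------------------------------------------------------------------------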
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Load config file
with open('config.yaml', 'r') as file:
    config = yaml.safe_load(file)
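# The keys accessed in this file imply a config.yaml shaped roughly like the
# sketch below (reconstructed from usage; the values are illustrative
# assumptions, not the real training config):
#
#   config:
#     process:
#       - trigger_word: "sagar"
#         sample:
#           guidance_scale: 3.5
#           sample_steps: 28
#           walk_seed: true
#           seed: 42
#           width: 1024
#           height: 1024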
# Authenticate using the token stored in Hugging Face Spaces secrets
if 'HF_TOKEN' in os.environ:
    login(token=os.environ['HF_TOKEN'])
    logging.info("Successfully logged in with HF_TOKEN")
else:
    logging.warning("HF_TOKEN not found in environment variables. Some functionality may be limited.")
# Pull the values we need out of the config
process_config = config['config']['process'][0]  # assume the first process entry is the one we want
base_model = "black-forest-labs/FLUX.1-dev"
lora_model = "sagar007/sagar_flux"  # hard-coded, not read from the config; loaded in initialize_model() below
trigger_word = process_config['trigger_word']
logging.info(f"Base model: {base_model}")
logging.info(f"LoRA model: {lora_model}")
logging.info(f"Trigger word: {trigger_word}")
# Global variables
pipe = None
cache = {}
CACHE_FILE = "image_cache.pkl"
# Example prompts
example_prompts = [
    "Photo of sagar as superman flying in the sky, cape billowing in the wind, sagar",
    "Professional photo of sagar for LinkedIn headshot, DSLR quality, neutral background, sagar",
    "Sagar as an astronaut exploring a distant alien planet, vibrant colors, sagar",
    "Sagar hiking in a lush green forest, sunlight filtering through the trees, sagar",
    "Sagar as a wizard casting a spell, magical energy swirling around, sagar",
    "Sagar scoring a goal in a dramatic soccer match, stadium lights shining, sagar",
    "Sagar as a Roman emperor addressing a crowd, wearing a toga and laurel wreath, sagar"
]
def initialize_model():
    global pipe
    if pipe is None:
        try:
            logging.info(f"Attempting to load model: {base_model}")
            pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16, use_safetensors=True)
            # Attach the LoRA weights; without this call the lora_model defined
            # above is never applied and the trigger word has no effect.
            logging.info(f"Loading LoRA weights: {lora_model}")
            pipe.load_lora_weights(lora_model)
            logging.info("Moving model to CUDA...")
            pipe = pipe.to("cuda")
            logging.info(f"Successfully loaded model: {base_model}")
        except Exception as e:
            logging.error(f"Error loading model {base_model}: {str(e)}")
            raise
def load_cache():
    global cache
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'rb') as f:
            cache = pickle.load(f)
        logging.info(f"Loaded {len(cache)} cached images")

def save_cache():
    with open(CACHE_FILE, 'wb') as f:
        pickle.dump(cache, f)
    logging.info(f"Saved {len(cache)} cached images")
def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    return hashlib.md5(f"{prompt}{cfg_scale}{steps}{seed}{width}{height}{lora_scale}".encode()).hexdigest()
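# Example (hypothetical values): calling
#   get_cache_key("sagar as a pilot", 3.5, 28, 42, 512, 512, 0.75)
# always returns the same 32-character hex digest, so a repeated request with
# identical settings is served from the cache instead of re-running the model.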
@spaces.GPU  # On ZeroGPU hardware a GPU is only attached inside functions carrying this decorator
def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
    global pipe, cache
    if randomize_seed:
        seed = random.randint(0, 2**32 - 1)
    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
    if cache_key in cache:
        logging.info("Using cached image")
        return cache[cache_key], seed
    try:
        logging.info(f"Starting run_lora with prompt: {prompt}")
        if pipe is None:
            logging.info("Initializing model...")
            initialize_model()
        logging.info(f"Using seed: {seed}")
        generator = torch.Generator(device="cuda").manual_seed(seed)
        full_prompt = f"{prompt} {trigger_word}"
        logging.info(f"Full prompt: {full_prompt}")
        logging.info("Starting image generation...")
        image = pipe(
            prompt=full_prompt,
            num_inference_steps=steps,
            guidance_scale=cfg_scale,
            width=width,
            height=height,
            generator=generator,
            # Forward the slider value so the LoRA strength actually takes
            # effect; the FLUX pipeline reads it from joint_attention_kwargs.
            joint_attention_kwargs={"scale": lora_scale},
        ).images[0]
        logging.info("Image generation completed successfully")
        # Cache the generated image so identical requests are served instantly
        cache[cache_key] = image
        save_cache()
        return image, seed
    except Exception as e:
        logging.error(f"Error during generation: {str(e)}")
        logging.error(traceback.format_exc())
        return None, seed
def update_prompt(example):
    return example
# Load cache at startup
load_cache()
# Pre-generate and cache example images. This runs before app.launch(), so a
# cold start pays the full generation cost for every example up front.
def cache_example_images():
    for prompt in example_prompts:
        run_lora(prompt, process_config['sample']['guidance_scale'], process_config['sample']['sample_steps'],
                 process_config['sample']['walk_seed'], process_config['sample']['seed'],
                 process_config['sample']['width'], process_config['sample']['height'], 0.75)
# Gradio interface setup
with gr.Blocks() as app:
    gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt")
            example_dropdown = gr.Dropdown(choices=example_prompts, label="Example Prompts")
            run_button = gr.Button("Generate")
        with gr.Column():
            result = gr.Image(label="Result")
    with gr.Row():
        cfg_scale = gr.Slider(minimum=1, maximum=20, value=process_config['sample']['guidance_scale'], step=0.1, label="CFG Scale")
        steps = gr.Slider(minimum=1, maximum=100, value=process_config['sample']['sample_steps'], step=1, label="Steps")
    with gr.Row():
        width = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['width'], step=64, label="Width")
        height = gr.Slider(minimum=128, maximum=1024, value=process_config['sample']['height'], step=64, label="Height")
    with gr.Row():
        seed = gr.Number(label="Seed", value=process_config['sample']['seed'], precision=0)
        randomize_seed = gr.Checkbox(label="Randomize seed", value=process_config['sample']['walk_seed'])
        lora_scale = gr.Slider(minimum=0, maximum=1, value=0.75, step=0.01, label="LoRA Scale")

    example_dropdown.change(update_prompt, inputs=[example_dropdown], outputs=[prompt])
    run_button.click(
        run_lora,
        inputs=[prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale],
        outputs=[result, seed]
    )
# Launch the app
if __name__ == "__main__":
    logging.info("Starting the Gradio app...")
    logging.info("Pre-generating example images...")
    cache_example_images()
    logging.info("Launching Gradio...")
    # launch() blocks until the server shuts down, so log before calling it.
    # share=True is not supported inside a Space, so it is omitted here.
    app.launch()