sagar007 committed
Commit ad1f99d
1 Parent(s): 67e5720

Update app.py

Files changed (1)
  1. app.py +45 -23
app.py CHANGED
@@ -5,6 +5,8 @@ from huggingface_hub import login, hf_hub_download
 import spaces
 import torch
 from diffusers import DiffusionPipeline
+import hashlib
+import pickle
 
 # Authenticate using the token stored in Hugging Face Spaces secrets
 if 'HF_TOKEN' in os.environ:
@@ -16,8 +18,10 @@ base_model = "black-forest-labs/FLUX.1-dev"
 lora_model = "sagar007/sagar_flux"
 trigger_word = "sagar"
 
-# Global variable for the pipeline
+# Global variables
 pipe = None
+cache = {}
+CACHE_FILE = "image_cache.pkl"
 
 # Example prompts
 example_prompts = [
@@ -39,47 +43,50 @@ def initialize_model():
         print("Moving model to CUDA...")
         pipe = pipe.to("cuda")
         print(f"Successfully loaded base model: {base_model}")
-
-        # Commenting out LoRA loading for now
-        # print("Downloading LoRA weights...")
-        # lora_path = download_lora_weights(lora_model)
-        # if lora_path:
-        #     print("Loading LoRA weights...")
-        #     pipe.load_lora_weights(lora_path)
-        #     print("Successfully loaded LoRA weights")
-        # else:
-        #     print("Failed to download LoRA weights. Continuing without LoRA.")
     except Exception as e:
         print(f"Error initializing model: {str(e)}")
         import traceback
         print(traceback.format_exc())
         raise
 
-def download_lora_weights(repo_id, filename="lora.safetensors"):
-    try:
-        lora_path = hf_hub_download(repo_id=repo_id, filename=filename)
-        print(f"Successfully downloaded LoRA weights from {repo_id}")
-        return lora_path
-    except Exception as e:
-        print(f"Error downloading LoRA weights: {str(e)}")
-        return None
+def load_cache():
+    global cache
+    if os.path.exists(CACHE_FILE):
+        with open(CACHE_FILE, 'rb') as f:
+            cache = pickle.load(f)
+        print(f"Loaded {len(cache)} cached images")
+
+def save_cache():
+    with open(CACHE_FILE, 'wb') as f:
+        pickle.dump(cache, f)
+    print(f"Saved {len(cache)} cached images")
+
+def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
+    return hashlib.md5(f"{prompt}{cfg_scale}{steps}{seed}{width}{height}{lora_scale}".encode()).hexdigest()
 
 @spaces.GPU(duration=80)
 def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale):
-    global pipe
+    global pipe, cache
+
+    if randomize_seed:
+        seed = random.randint(0, 2**32-1)
+
+    cache_key = get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale)
+
+    if cache_key in cache:
+        print("Using cached image")
+        return cache[cache_key], seed
+
     try:
         print(f"Starting run_lora with prompt: {prompt}")
         if pipe is None:
             print("Initializing model...")
             initialize_model()
 
-        if randomize_seed:
-            seed = random.randint(0, 2**32-1)
         print(f"Using seed: {seed}")
 
        generator = torch.Generator(device="cuda").manual_seed(seed)
 
-        # Include the trigger word in the prompt
         full_prompt = f"{prompt} {trigger_word}"
         print(f"Full prompt: {full_prompt}")
 
@@ -93,6 +100,11 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
             generator=generator,
         ).images[0]
         print("Image generation completed successfully")
+
+        # Cache the generated image
+        cache[cache_key] = image
+        save_cache()
+
         return image, seed
     except Exception as e:
         print(f"Error during generation: {str(e)}")
@@ -103,6 +115,14 @@ def run_lora(prompt, cfg_scale, steps, randomize_seed, seed, width, height, lora
 def update_prompt(example):
     return example
 
+# Load cache at startup
+load_cache()
+
+# Pre-generate and cache example images
+def cache_example_images():
+    for prompt in example_prompts:
+        run_lora(prompt, 4, 20, False, 42, 1024, 1024, 0.75)
+
 # Gradio interface setup
 with gr.Blocks() as app:
     gr.Markdown("# Text-to-Image Generation with FLUX (ZeroGPU)")
@@ -135,5 +155,7 @@ with gr.Blocks() as app:
 # Launch the app
 if __name__ == "__main__":
     print("Starting the Gradio app...")
+    print("Pre-generating example images...")
+    cache_example_images()
     app.launch(share=True)
     print("Gradio app launched successfully")
 
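One caveat on get_cache_key as committed: it concatenates the raw field values with no delimiter before hashing, so distinct settings can collide. For example, prompt "a" with cfg_scale 41 and prompt "a4" with cfg_scale 1 both produce the prefix "a41" and, with the remaining fields equal, the same MD5 key. Serializing the fields with explicit boundaries removes the ambiguity; a small sketch, not what the commit ships:

import hashlib
import json

def get_cache_key(prompt, cfg_scale, steps, seed, width, height, lora_scale):
    # json.dumps gives every field an unambiguous boundary, so
    # ("a", 41, ...) and ("a4", 1, ...) no longer hash identically.
    payload = json.dumps(
        [prompt, cfg_scale, steps, seed, width, height, lora_scale],
        separators=(",", ":"),
    )
    return hashlib.md5(payload.encode()).hexdigest()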
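On startup behavior: cache_example_images() now runs before app.launch(share=True), and each call goes through run_lora, which carries @spaces.GPU(duration=80), so every uncached example prompt occupies a GPU slot and the Space only starts serving once all examples have rendered. With a warm cache this is nearly a no-op, since the fixed arguments (cfg 4, 20 steps, seed 42, 1024x1024, LoRA scale 0.75) yield deterministic keys, and the cache check sits before initialize_model(), so hits return without touching the GPU. As written, though, one failed generation aborts launch; a sketch of a more forgiving loop, with a hypothetical wrapper name:

def cache_example_images_safely():
    # Pre-render each example, but let the app launch even if a
    # single prompt fails (e.g. a GPU quota error at startup).
    for prompt in example_prompts:
        try:
            run_lora(prompt, 4, 20, False, 42, 1024, 1024, 0.75)
        except Exception as e:
            print(f"Could not pre-generate '{prompt}': {e}")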