fantos commited on
Commit
0b34ea3
·
verified ·
1 Parent(s): 5e92500

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -45
app.py CHANGED
@@ -14,7 +14,8 @@ from PIL import Image
14
 
15
  # Setup and initialization code
16
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
17
- gallery_path = path.join(path.dirname(path.abspath(__file__)), "gallery")
 
18
  os.environ["TRANSFORMERS_CACHE"] = cache_path
19
  os.environ["HF_HUB_CACHE"] = cache_path
20
  os.environ["HF_HOME"] = cache_path
@@ -110,19 +111,6 @@ footer {display: none !important}
110
  border-radius: 4px !important;
111
  transition: transform 0.2s;
112
  }
113
- /* Force gallery items to maintain aspect ratio */
114
- .gallery-item {
115
- width: 100% !important;
116
- aspect-ratio: 1 !important;
117
- overflow: hidden !important;
118
- }
119
- .gallery-item img {
120
- width: 100% !important;
121
- height: 100% !important;
122
- object-fit: cover !important;
123
- border-radius: 4px;
124
- transition: transform 0.2s;
125
- }
126
  .gallery-item img:hover {
127
  transform: scale(1.05);
128
  }
@@ -161,30 +149,60 @@ footer {display: none !important}
161
 
162
  def save_image(image):
163
  """Save the generated image and return the path"""
 
 
 
 
164
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
165
- filename = f"generated_{timestamp}.png"
 
166
  filepath = os.path.join(gallery_path, filename)
167
 
168
- if isinstance(image, Image.Image):
169
- image.save(filepath)
170
- else:
171
- image = Image.fromarray(image)
172
- image.save(filepath)
173
-
174
- return filepath
 
 
 
 
 
 
 
 
 
 
175
 
176
  def load_gallery():
177
  """Load all images from the gallery directory"""
178
- image_files = [f for f in os.listdir(gallery_path) if f.endswith(('.png', '.jpg', '.jpeg'))]
179
- image_files.sort(reverse=True) # Most recent first
180
- return [os.path.join(gallery_path, f) for f in image_files]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
181
 
182
  # Create Gradio interface
183
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
184
  gr.HTML('<div class="title">AI Image Generator</div>')
185
  gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
186
 
187
-
188
  with gr.Row():
189
  with gr.Column(scale=3):
190
  prompt = gr.Textbox(
@@ -269,12 +287,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
269
  """)
270
 
271
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
272
- # Current generated image
273
  output = gr.Image(
274
  label="Generated Image",
275
  elem_id="output-image",
276
  elem_classes=["output-image", "fixed-width"]
277
  )
 
 
278
  gallery = gr.Gallery(
279
  label="Generated Images Gallery",
280
  show_label=True,
@@ -285,9 +305,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
285
  object_fit="cover",
286
  elem_classes=["gallery-container", "fixed-width"]
287
  )
288
-
289
-
290
-
291
 
292
  # Load existing gallery images on startup
293
  gallery.value = load_gallery()
@@ -296,21 +313,27 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
296
  def process_and_save_image(height, width, steps, scales, prompt, seed):
297
  global pipe
298
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
299
- generated_image = pipe(
300
- prompt=[prompt],
301
- generator=torch.Generator().manual_seed(int(seed)),
302
- num_inference_steps=int(steps),
303
- guidance_scale=float(scales),
304
- height=int(height),
305
- width=int(width),
306
- max_sequence_length=256
307
- ).images[0]
308
-
309
- # Save the generated image
310
- save_image(generated_image)
311
-
312
- # Return both the generated image and updated gallery
313
- return generated_image, load_gallery()
 
 
 
 
 
 
314
 
315
  # Connect the generation button to both the image output and gallery update
316
  def update_seed():
 
14
 
15
  # Setup and initialization code
16
  cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
17
+ # Change gallery path to user's home directory for persistence
18
+ gallery_path = path.join(os.path.expanduser("~"), "ai_generated_images")
19
  os.environ["TRANSFORMERS_CACHE"] = cache_path
20
  os.environ["HF_HUB_CACHE"] = cache_path
21
  os.environ["HF_HOME"] = cache_path
 
111
  border-radius: 4px !important;
112
  transition: transform 0.2s;
113
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  .gallery-item img:hover {
115
  transform: scale(1.05);
116
  }
 
149
 
150
  def save_image(image):
151
  """Save the generated image and return the path"""
152
+ # Ensure gallery directory exists
153
+ os.makedirs(gallery_path, exist_ok=True)
154
+
155
+ # Generate unique filename with timestamp and random suffix
156
  timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
157
+ random_suffix = os.urandom(4).hex()
158
+ filename = f"generated_{timestamp}_{random_suffix}.png"
159
  filepath = os.path.join(gallery_path, filename)
160
 
161
+ try:
162
+ if isinstance(image, Image.Image):
163
+ # Save with maximum quality
164
+ image.save(filepath, "PNG", quality=100)
165
+ else:
166
+ image = Image.fromarray(image)
167
+ image.save(filepath, "PNG", quality=100)
168
+
169
+ # Verify the file was saved correctly
170
+ if not os.path.exists(filepath):
171
+ print(f"Warning: Failed to verify saved image at {filepath}")
172
+ return None
173
+
174
+ return filepath
175
+ except Exception as e:
176
+ print(f"Error saving image: {str(e)}")
177
+ return None
178
 
179
  def load_gallery():
180
  """Load all images from the gallery directory"""
181
+ try:
182
+ # Ensure gallery directory exists
183
+ os.makedirs(gallery_path, exist_ok=True)
184
+
185
+ # Get all image files and sort by modification time
186
+ image_files = []
187
+ for f in os.listdir(gallery_path):
188
+ if f.lower().endswith(('.png', '.jpg', '.jpeg')):
189
+ full_path = os.path.join(gallery_path, f)
190
+ image_files.append((full_path, os.path.getmtime(full_path)))
191
+
192
+ # Sort by modification time (newest first)
193
+ image_files.sort(key=lambda x: x[1], reverse=True)
194
+
195
+ # Return only the file paths
196
+ return [f[0] for f in image_files]
197
+ except Exception as e:
198
+ print(f"Error loading gallery: {str(e)}")
199
+ return []
200
 
201
  # Create Gradio interface
202
  with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
203
  gr.HTML('<div class="title">AI Image Generator</div>')
204
  gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
205
 
 
206
  with gr.Row():
207
  with gr.Column(scale=3):
208
  prompt = gr.Textbox(
 
287
  """)
288
 
289
  with gr.Column(scale=4, elem_classes=["fixed-width"]):
290
+ # Current generated image
291
  output = gr.Image(
292
  label="Generated Image",
293
  elem_id="output-image",
294
  elem_classes=["output-image", "fixed-width"]
295
  )
296
+
297
+ # Gallery of generated images
298
  gallery = gr.Gallery(
299
  label="Generated Images Gallery",
300
  show_label=True,
 
305
  object_fit="cover",
306
  elem_classes=["gallery-container", "fixed-width"]
307
  )
 
 
 
308
 
309
  # Load existing gallery images on startup
310
  gallery.value = load_gallery()
 
313
  def process_and_save_image(height, width, steps, scales, prompt, seed):
314
  global pipe
315
  with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
316
+ try:
317
+ generated_image = pipe(
318
+ prompt=[prompt],
319
+ generator=torch.Generator().manual_seed(int(seed)),
320
+ num_inference_steps=int(steps),
321
+ guidance_scale=float(scales),
322
+ height=int(height),
323
+ width=int(width),
324
+ max_sequence_length=256
325
+ ).images[0]
326
+
327
+ # Save the generated image
328
+ saved_path = save_image(generated_image)
329
+ if saved_path is None:
330
+ print("Warning: Failed to save generated image")
331
+
332
+ # Return both the generated image and updated gallery
333
+ return generated_image, load_gallery()
334
+ except Exception as e:
335
+ print(f"Error in image generation: {str(e)}")
336
+ return None, load_gallery()
337
 
338
  # Connect the generation button to both the image output and gallery update
339
  def update_seed():