Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -14,7 +14,8 @@ from PIL import Image
|
|
14 |
|
15 |
# Setup and initialization code
|
16 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
17 |
-
|
|
|
18 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
19 |
os.environ["HF_HUB_CACHE"] = cache_path
|
20 |
os.environ["HF_HOME"] = cache_path
|
@@ -110,19 +111,6 @@ footer {display: none !important}
|
|
110 |
border-radius: 4px !important;
|
111 |
transition: transform 0.2s;
|
112 |
}
|
113 |
-
/* Force gallery items to maintain aspect ratio */
|
114 |
-
.gallery-item {
|
115 |
-
width: 100% !important;
|
116 |
-
aspect-ratio: 1 !important;
|
117 |
-
overflow: hidden !important;
|
118 |
-
}
|
119 |
-
.gallery-item img {
|
120 |
-
width: 100% !important;
|
121 |
-
height: 100% !important;
|
122 |
-
object-fit: cover !important;
|
123 |
-
border-radius: 4px;
|
124 |
-
transition: transform 0.2s;
|
125 |
-
}
|
126 |
.gallery-item img:hover {
|
127 |
transform: scale(1.05);
|
128 |
}
|
@@ -161,30 +149,60 @@ footer {display: none !important}
|
|
161 |
|
162 |
def save_image(image):
|
163 |
"""Save the generated image and return the path"""
|
|
|
|
|
|
|
|
|
164 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
165 |
-
|
|
|
166 |
filepath = os.path.join(gallery_path, filename)
|
167 |
|
168 |
-
|
169 |
-
image.
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
175 |
|
176 |
def load_gallery():
|
177 |
"""Load all images from the gallery directory"""
|
178 |
-
|
179 |
-
|
180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
# Create Gradio interface
|
183 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
184 |
gr.HTML('<div class="title">AI Image Generator</div>')
|
185 |
gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
|
186 |
|
187 |
-
|
188 |
with gr.Row():
|
189 |
with gr.Column(scale=3):
|
190 |
prompt = gr.Textbox(
|
@@ -269,12 +287,14 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
269 |
""")
|
270 |
|
271 |
with gr.Column(scale=4, elem_classes=["fixed-width"]):
|
272 |
-
|
273 |
output = gr.Image(
|
274 |
label="Generated Image",
|
275 |
elem_id="output-image",
|
276 |
elem_classes=["output-image", "fixed-width"]
|
277 |
)
|
|
|
|
|
278 |
gallery = gr.Gallery(
|
279 |
label="Generated Images Gallery",
|
280 |
show_label=True,
|
@@ -285,9 +305,6 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
285 |
object_fit="cover",
|
286 |
elem_classes=["gallery-container", "fixed-width"]
|
287 |
)
|
288 |
-
|
289 |
-
|
290 |
-
|
291 |
|
292 |
# Load existing gallery images on startup
|
293 |
gallery.value = load_gallery()
|
@@ -296,21 +313,27 @@ with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
|
296 |
def process_and_save_image(height, width, steps, scales, prompt, seed):
|
297 |
global pipe
|
298 |
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
|
313 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
314 |
|
315 |
# Connect the generation button to both the image output and gallery update
|
316 |
def update_seed():
|
|
|
14 |
|
15 |
# Setup and initialization code
|
16 |
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
|
17 |
+
# Change gallery path to user's home directory for persistence
|
18 |
+
gallery_path = path.join(os.path.expanduser("~"), "ai_generated_images")
|
19 |
os.environ["TRANSFORMERS_CACHE"] = cache_path
|
20 |
os.environ["HF_HUB_CACHE"] = cache_path
|
21 |
os.environ["HF_HOME"] = cache_path
|
|
|
111 |
border-radius: 4px !important;
|
112 |
transition: transform 0.2s;
|
113 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
.gallery-item img:hover {
|
115 |
transform: scale(1.05);
|
116 |
}
|
|
|
149 |
|
150 |
def save_image(image):
|
151 |
"""Save the generated image and return the path"""
|
152 |
+
# Ensure gallery directory exists
|
153 |
+
os.makedirs(gallery_path, exist_ok=True)
|
154 |
+
|
155 |
+
# Generate unique filename with timestamp and random suffix
|
156 |
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
157 |
+
random_suffix = os.urandom(4).hex()
|
158 |
+
filename = f"generated_{timestamp}_{random_suffix}.png"
|
159 |
filepath = os.path.join(gallery_path, filename)
|
160 |
|
161 |
+
try:
|
162 |
+
if isinstance(image, Image.Image):
|
163 |
+
# Save with maximum quality
|
164 |
+
image.save(filepath, "PNG", quality=100)
|
165 |
+
else:
|
166 |
+
image = Image.fromarray(image)
|
167 |
+
image.save(filepath, "PNG", quality=100)
|
168 |
+
|
169 |
+
# Verify the file was saved correctly
|
170 |
+
if not os.path.exists(filepath):
|
171 |
+
print(f"Warning: Failed to verify saved image at {filepath}")
|
172 |
+
return None
|
173 |
+
|
174 |
+
return filepath
|
175 |
+
except Exception as e:
|
176 |
+
print(f"Error saving image: {str(e)}")
|
177 |
+
return None
|
178 |
|
179 |
def load_gallery():
|
180 |
"""Load all images from the gallery directory"""
|
181 |
+
try:
|
182 |
+
# Ensure gallery directory exists
|
183 |
+
os.makedirs(gallery_path, exist_ok=True)
|
184 |
+
|
185 |
+
# Get all image files and sort by modification time
|
186 |
+
image_files = []
|
187 |
+
for f in os.listdir(gallery_path):
|
188 |
+
if f.lower().endswith(('.png', '.jpg', '.jpeg')):
|
189 |
+
full_path = os.path.join(gallery_path, f)
|
190 |
+
image_files.append((full_path, os.path.getmtime(full_path)))
|
191 |
+
|
192 |
+
# Sort by modification time (newest first)
|
193 |
+
image_files.sort(key=lambda x: x[1], reverse=True)
|
194 |
+
|
195 |
+
# Return only the file paths
|
196 |
+
return [f[0] for f in image_files]
|
197 |
+
except Exception as e:
|
198 |
+
print(f"Error loading gallery: {str(e)}")
|
199 |
+
return []
|
200 |
|
201 |
# Create Gradio interface
|
202 |
with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
|
203 |
gr.HTML('<div class="title">AI Image Generator</div>')
|
204 |
gr.HTML('<div style="text-align: center; margin-bottom: 2em; color: #666;">Create stunning images from your descriptions</div>')
|
205 |
|
|
|
206 |
with gr.Row():
|
207 |
with gr.Column(scale=3):
|
208 |
prompt = gr.Textbox(
|
|
|
287 |
""")
|
288 |
|
289 |
with gr.Column(scale=4, elem_classes=["fixed-width"]):
|
290 |
+
# Current generated image
|
291 |
output = gr.Image(
|
292 |
label="Generated Image",
|
293 |
elem_id="output-image",
|
294 |
elem_classes=["output-image", "fixed-width"]
|
295 |
)
|
296 |
+
|
297 |
+
# Gallery of generated images
|
298 |
gallery = gr.Gallery(
|
299 |
label="Generated Images Gallery",
|
300 |
show_label=True,
|
|
|
305 |
object_fit="cover",
|
306 |
elem_classes=["gallery-container", "fixed-width"]
|
307 |
)
|
|
|
|
|
|
|
308 |
|
309 |
# Load existing gallery images on startup
|
310 |
gallery.value = load_gallery()
|
|
|
313 |
def process_and_save_image(height, width, steps, scales, prompt, seed):
|
314 |
global pipe
|
315 |
with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
|
316 |
+
try:
|
317 |
+
generated_image = pipe(
|
318 |
+
prompt=[prompt],
|
319 |
+
generator=torch.Generator().manual_seed(int(seed)),
|
320 |
+
num_inference_steps=int(steps),
|
321 |
+
guidance_scale=float(scales),
|
322 |
+
height=int(height),
|
323 |
+
width=int(width),
|
324 |
+
max_sequence_length=256
|
325 |
+
).images[0]
|
326 |
+
|
327 |
+
# Save the generated image
|
328 |
+
saved_path = save_image(generated_image)
|
329 |
+
if saved_path is None:
|
330 |
+
print("Warning: Failed to save generated image")
|
331 |
+
|
332 |
+
# Return both the generated image and updated gallery
|
333 |
+
return generated_image, load_gallery()
|
334 |
+
except Exception as e:
|
335 |
+
print(f"Error in image generation: {str(e)}")
|
336 |
+
return None, load_gallery()
|
337 |
|
338 |
# Connect the generation button to both the image output and gallery update
|
339 |
def update_seed():
|