fffiloni commited on
Commit
a24bb36
1 Parent(s): 3a8fba4

Update app_gradio.py

Browse files
Files changed (1) hide show
  1. app_gradio.py +30 -19
app_gradio.py CHANGED
@@ -24,6 +24,7 @@ from invert_utils import ddim_inversion as dd_inversion
24
  from gifs_filter import filter
25
  import subprocess
26
  import uuid
 
27
 
28
  from huggingface_hub import snapshot_download
29
 
@@ -159,34 +160,44 @@ def generate_output(image, apply_filter, prompt: str, num_seeds: int = 3, lambda
159
  if prompt is None:
160
  raise gr.Error("You forgot to describe the motion !")
161
  """Main function to generate output GIFs"""
162
- unique_id = str(uuid.uuid4())
163
- exp_dir = f"static/app_tmp_{unique_id}"
 
164
  os.makedirs(exp_dir, exist_ok=True)
165
 
166
  # Save the input image temporarily
167
- temp_image_path = os.path.join(exp_dir, "temp_input.png")
 
168
 
169
  image = Image.open(image)
170
  image = image.resize((256, 256), Image.LANCZOS)
171
 
172
  image.save(temp_image_path)
173
 
174
- # Generate the GIFs
175
- generated_gifs = process_video(
176
- num_frames=10,
177
- num_seeds=num_seeds,
178
- generator=None,
179
- exp_dir=exp_dir,
180
- load_name=temp_image_path,
181
- caption=prompt,
182
- lambda_=1 - lambda_value
183
- )
184
-
185
- if apply_filter == True:
186
- print("APPLYING FILTER")
187
- # Apply filtering (assuming filter function is imported)
188
- filtered_gifs = filter(generated_gifs, temp_image_path)
189
- return filtered_gifs, filtered_gifs
 
 
 
 
 
 
 
 
190
  else:
191
  print("NOT APPLYING FILTER")
192
  return generated_gifs, generated_gifs
 
24
  from gifs_filter import filter
25
  import subprocess
26
  import uuid
27
+ import tempfile
28
 
29
  from huggingface_hub import snapshot_download
30
 
 
160
  if prompt is None:
161
  raise gr.Error("You forgot to describe the motion !")
162
  """Main function to generate output GIFs"""
163
+
164
+ # Create a temporary directory for this session
165
+ exp_dir = tempfile.mkdtemp(prefix="app_tmp_")
166
  os.makedirs(exp_dir, exist_ok=True)
167
 
168
  # Save the input image temporarily
169
+ unique_id = str(uuid.uuid4())
170
+ temp_image_path = os.path.join(exp_dir, f"temp_input_{unique_id}.png")
171
 
172
  image = Image.open(image)
173
  image = image.resize((256, 256), Image.LANCZOS)
174
 
175
  image.save(temp_image_path)
176
 
177
+ try:
178
+ # Attempt to process video
179
+ generated_gifs = process_video(
180
+ num_frames=10,
181
+ num_seeds=num_seeds,
182
+ generator=None,
183
+ exp_dir=exp_dir,
184
+ load_name=temp_image_path,
185
+ caption=prompt,
186
+ lambda_=1 - lambda_value
187
+ )
188
+ except Exception as e:
189
+ torch.cuda.empty_cache() # Clear CUDA cache in case of failure
190
+ raise gr.Error(f"Video processing failed: {str(e)}") from e
191
+
192
+ if apply_filter:
193
+ try:
194
+ print("APPLYING FILTER")
195
+ # Attempt to apply filtering
196
+ filtered_gifs = filter(generated_gifs, temp_image_path)
197
+ return filtered_gifs, filtered_gifs
198
+ except Exception as e:
199
+ torch.cuda.empty_cache() # Clear CUDA cache in case of failure
200
+ raise gr.Error(f"Filtering failed: {str(e)}") from e
201
  else:
202
  print("NOT APPLYING FILTER")
203
  return generated_gifs, generated_gifs