multimodalart HF staff commited on
Commit
702754c
1 Parent(s): dae6484

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -0
app.py CHANGED
@@ -7,6 +7,8 @@ from huggingface_hub import snapshot_download
7
  from pyramid_dit import PyramidDiTForVideoGeneration
8
  from diffusers.utils import export_to_video
9
 
 
 
10
  # Constants
11
  MODEL_PATH = "pyramid-flow-model"
12
  MODEL_REPO = "rain1011/pyramid-flow-sd3"
@@ -35,6 +37,7 @@ def load_model():
35
  model = load_model()
36
 
37
  # Text-to-video generation function
 
38
  def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
39
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
40
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
@@ -58,6 +61,7 @@ def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
58
  return output_path
59
 
60
  # Image-to-video generation function
 
61
  def generate_video_from_image(image, prompt, duration, video_guidance_scale):
62
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
63
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
 
7
  from pyramid_dit import PyramidDiTForVideoGeneration
8
  from diffusers.utils import export_to_video
9
 
10
+ import spaces
11
+
12
  # Constants
13
  MODEL_PATH = "pyramid-flow-model"
14
  MODEL_REPO = "rain1011/pyramid-flow-sd3"
 
37
  model = load_model()
38
 
39
  # Text-to-video generation function
40
+ @spaces.GPU(duration=240)
41
  def generate_video(prompt, duration, guidance_scale, video_guidance_scale):
42
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
43
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32
 
61
  return output_path
62
 
63
  # Image-to-video generation function
64
+ @spaces.GPU(duration=240)
65
  def generate_video_from_image(image, prompt, duration, video_guidance_scale):
66
  temp = int(duration * 2.4) # Convert seconds to temp value (assuming 24 FPS)
67
  torch_dtype = torch.bfloat16 if MODEL_DTYPE == "bf16" else torch.float32