Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
•
d53d73c
1
Parent(s):
d9114d9
update
Browse files
app.py
CHANGED
@@ -19,6 +19,13 @@ from funcs import (
|
|
19 |
get_latent_z,
|
20 |
save_videos
|
21 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
def download_model():
|
24 |
REPO_ID = 'Doubiiu/DynamiCrafter_1024'
|
@@ -43,7 +50,7 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
|
|
43 |
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
|
44 |
model = load_model_checkpoint(model, ckpt_path)
|
45 |
model.eval()
|
46 |
-
model = model.
|
47 |
save_fps = 8
|
48 |
|
49 |
seed_everything(seed)
|
@@ -51,7 +58,10 @@ def infer(image, prompt, steps=50, cfg_scale=7.5, eta=1.0, fs=3, seed=123):
|
|
51 |
transforms.Resize(min(resolution)),
|
52 |
transforms.CenterCrop(resolution),
|
53 |
])
|
54 |
-
|
|
|
|
|
|
|
55 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
56 |
start = time.time()
|
57 |
if steps > 60:
|
@@ -154,4 +164,4 @@ with gr.Blocks(analytics_enabled=False, css=css) as dynamicrafter_iface:
|
|
154 |
fn = infer
|
155 |
)
|
156 |
|
157 |
-
dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
|
|
|
19 |
get_latent_z,
|
20 |
save_videos
|
21 |
)
|
22 |
+
if torch.cuda.is_available():
|
23 |
+
device = "cuda"
|
24 |
+
elif torch.backends.mps.is_available():
|
25 |
+
device = "mps"
|
26 |
+
else:
|
27 |
+
device = "cpu"
|
28 |
+
|
29 |
|
30 |
def download_model():
|
31 |
REPO_ID = 'Doubiiu/DynamiCrafter_1024'
|
|
|
50 |
assert os.path.exists(ckpt_path), "Error: checkpoint Not Found!"
|
51 |
model = load_model_checkpoint(model, ckpt_path)
|
52 |
model.eval()
|
53 |
+
model = model.to(device)
|
54 |
save_fps = 8
|
55 |
|
56 |
seed_everything(seed)
|
|
|
58 |
transforms.Resize(min(resolution)),
|
59 |
transforms.CenterCrop(resolution),
|
60 |
])
|
61 |
+
if device == "cuda":
|
62 |
+
torch.cuda.empty_cache()
|
63 |
+
elif device == "mps":
|
64 |
+
torch.mps.empty_cache()
|
65 |
print('start:', prompt, time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())))
|
66 |
start = time.time()
|
67 |
if steps > 60:
|
|
|
164 |
fn = infer
|
165 |
)
|
166 |
|
167 |
+
dynamicrafter_iface.queue(max_size=12).launch(show_api=True)
|