Spaces:
Runtime error
Runtime error
ShoufaChen
commited on
Commit
•
e496e33
1
Parent(s):
e556404
minor
Browse files
app.py
CHANGED
@@ -76,7 +76,7 @@ def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
|
|
76 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
77 |
|
78 |
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
79 |
-
if
|
80 |
index_sample = index_sample[:len(class_labels)]
|
81 |
t2 = time.time()
|
82 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|
|
|
76 |
print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
|
77 |
|
78 |
index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
|
79 |
+
if cfg_scale > 1.0:
|
80 |
index_sample = index_sample[:len(class_labels)]
|
81 |
t2 = time.time()
|
82 |
samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
|