Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
@@ -71,8 +71,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
|
|
71 |
# add to fix 'nan' and 'inf'
|
72 |
probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
|
73 |
probs = torch.clamp(probs, min=0, max=None)
|
74 |
-
print(f'inf:{torch.any(torch.isinf(probs))}')
|
75 |
-
print(f'nan: {torch.any(torch.isnan(probs))}')
|
76 |
|
77 |
idx = torch.multinomial(probs, num_samples=1)
|
78 |
else:
|
@@ -139,6 +137,7 @@ def decode_n_tokens(
|
|
139 |
|
140 |
@torch.no_grad()
|
141 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
|
|
142 |
print(condition)
|
143 |
if condition is not None:
|
144 |
condition = model.adapter(condition)
|
|
|
71 |
# add to fix 'nan' and 'inf'
|
72 |
probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
|
73 |
probs = torch.clamp(probs, min=0, max=None)
|
|
|
|
|
74 |
|
75 |
idx = torch.multinomial(probs, num_samples=1)
|
76 |
else:
|
|
|
137 |
|
138 |
@torch.no_grad()
|
139 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
140 |
+
condition = condition.to(torch.float32)
|
141 |
print(condition)
|
142 |
if condition is not None:
|
143 |
condition = model.adapter(condition)
|