Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
@@ -68,10 +68,10 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
|
|
68 |
# mask = (probs == values).float()
|
69 |
# probs = probs * (1 - mask)
|
70 |
if sample_logits:
|
71 |
-
|
72 |
probs = torch.clamp(probs, min=0, max=None)
|
73 |
probs = probs / probs.sum()
|
74 |
-
|
75 |
idx = torch.multinomial(probs, num_samples=1)
|
76 |
else:
|
77 |
_, idx = torch.topk(probs, k=1, dim=-1)
|
|
|
68 |
# mask = (probs == values).float()
|
69 |
# probs = probs * (1 - mask)
|
70 |
if sample_logits:
|
71 |
+
# add to fix 'nan' and 'inf'
|
72 |
probs = torch.clamp(probs, min=0, max=None)
|
73 |
probs = probs / probs.sum()
|
74 |
+
|
75 |
idx = torch.multinomial(probs, num_samples=1)
|
76 |
else:
|
77 |
_, idx = torch.topk(probs, k=1, dim=-1)
|