wondervictor commited on
Commit
b1416e6
·
verified ·
1 Parent(s): fa0ea4a

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +4 -0
autoregressive/models/generate.py CHANGED
@@ -68,6 +68,10 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
68
  # mask = (probs == values).float()
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
 
 
 
 
71
  idx = torch.multinomial(probs, num_samples=1)
72
  else:
73
  _, idx = torch.topk(probs, k=1, dim=-1)
 
68
  # mask = (probs == values).float()
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
71
+ ### add to fix 'nan' and 'inf'
72
+ probs = torch.clamp(probs, min=0, max=None)
73
+ probs = probs / probs.sum()
74
+ ###
75
  idx = torch.multinomial(probs, num_samples=1)
76
  else:
77
  _, idx = torch.topk(probs, k=1, dim=-1)