wondervictor commited on
Commit
355654c
·
verified ·
1 Parent(s): 4fec13e

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +1 -1
autoregressive/models/generate.py CHANGED
@@ -69,7 +69,7 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
71
  # add to fix 'nan' and 'inf'
72
- # probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
73
  probs = torch.clamp(probs, min=0, max=None)
74
  # probs = probs / probs.sum()
75
  print(f'inf:{torch.any(torch.isinf(probs))}')
 
69
  # probs = probs * (1 - mask)
70
  if sample_logits:
71
  # add to fix 'nan' and 'inf'
72
+ probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
73
  probs = torch.clamp(probs, min=0, max=None)
74
  # probs = probs / probs.sum()
75
  print(f'inf:{torch.any(torch.isinf(probs))}')