wondervictor commited on
Commit
eb34ac9
·
verified ·
1 Parent(s): 31086e2

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +1 -2
autoregressive/models/generate.py CHANGED
@@ -71,8 +71,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
71
  # add to fix 'nan' and 'inf'
72
  probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
73
  probs = torch.clamp(probs, min=0, max=None)
74
- print(f'inf:{torch.any(torch.isinf(probs))}')
75
- print(f'nan: {torch.any(torch.isnan(probs))}')
76
 
77
  idx = torch.multinomial(probs, num_samples=1)
78
  else:
@@ -139,6 +137,7 @@ def decode_n_tokens(
139
 
140
  @torch.no_grad()
141
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
 
142
  print(condition)
143
  if condition is not None:
144
  condition = model.adapter(condition)
 
71
  # add to fix 'nan' and 'inf'
72
  probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
73
  probs = torch.clamp(probs, min=0, max=None)
 
 
74
 
75
  idx = torch.multinomial(probs, num_samples=1)
76
  else:
 
137
 
138
  @torch.no_grad()
139
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
140
+ condition = condition.to(torch.float32)
141
  print(condition)
142
  if condition is not None:
143
  condition = model.adapter(condition)