wondervictor committed on
Commit
24c3c11
·
verified ·
1 Parent(s): 92ca1db

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +2 -2
autoregressive/models/generate.py CHANGED
@@ -60,8 +60,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
60
  logits = logits[:, -1, :] / max(temperature, 1e-5)
61
  if top_k > 0 or top_p < 1.0:
62
  logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
63
- print(logits.sum())
64
- print(logits)
65
  probs = F.softmax(logits, dim=-1)
66
  # values, indices = torch.max(probs, dim=1, keepdim=True)
67
  # mask = (probs == values).float()
@@ -93,6 +91,8 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
93
  def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
94
  if cfg_scale > 1.0:
95
  logits, _ = model(None, cond_idx, input_pos, condition=condition)
 
 
96
  logits_combined = logits
97
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
98
  logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
 
60
  logits = logits[:, -1, :] / max(temperature, 1e-5)
61
  if top_k > 0 or top_p < 1.0:
62
  logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
 
 
63
  probs = F.softmax(logits, dim=-1)
64
  # values, indices = torch.max(probs, dim=1, keepdim=True)
65
  # mask = (probs == values).float()
 
91
  def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
92
  if cfg_scale > 1.0:
93
  logits, _ = model(None, cond_idx, input_pos, condition=condition)
94
+ print(logits.sum())
95
+ print(logits)
96
  logits_combined = logits
97
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
98
  logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale