wondervictor committed on
Commit
f3e5467
·
verified ·
1 Parent(s): fbb8b6f

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +4 -4
autoregressive/models/generate.py CHANGED
@@ -87,9 +87,9 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
87
  return probs
88
 
89
 
90
- def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
91
  if cfg_scale > 1.0:
92
- logits, _ = model(None, cond_idx, input_pos, condition=condition)
93
  logits_combined = logits
94
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
95
  logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
@@ -137,7 +137,7 @@ def decode_n_tokens(
137
 
138
 
139
  @torch.no_grad()
140
- def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
141
  # print("cond", torch.any(torch.isnan(cond)))
142
  if condition is not None:
143
  with torch.no_grad():
@@ -207,7 +207,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
207
  # create an empty tensor of the expected final shape and fill in the current tokens
208
  seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
209
  input_pos = torch.arange(0, T, device=device)
210
- next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
211
  seq[:, T:T+1] = next_token
212
 
213
  input_pos = torch.tensor([T], device=device, dtype=torch.int)
 
87
  return probs
88
 
89
 
90
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, control_strength: float=1, **sampling_kwargs):
91
  if cfg_scale > 1.0:
92
+ logits, _ = model(None, cond_idx, input_pos, condition=condition, control_strength=control_strength)
93
  logits_combined = logits
94
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
95
  logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
 
137
 
138
 
139
  @torch.no_grad()
140
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, control_strength=1, **sampling_kwargs):
141
  # print("cond", torch.any(torch.isnan(cond)))
142
  if condition is not None:
143
  with torch.no_grad():
 
207
  # create an empty tensor of the expected final shape and fill in the current tokens
208
  seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
209
  input_pos = torch.arange(0, T, device=device)
210
+ next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
211
  seq[:, T:T+1] = next_token
212
 
213
  input_pos = torch.tensor([T], device=device, dtype=torch.int)