wondervictor committed
Update autoregressive/models/generate.py
autoregressive/models/generate.py
CHANGED
@@ -87,9 +87,9 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
     return probs
 
 
-def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
+def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, control_strength: float=1, **sampling_kwargs):
     if cfg_scale > 1.0:
-        logits, _ = model(None, cond_idx, input_pos, condition=condition)
+        logits, _ = model(None, cond_idx, input_pos, condition=condition, control_strength=control_strength)
         logits_combined = logits
         cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
         logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
@@ -137,7 +137,7 @@ def decode_n_tokens(
 
 
 @torch.no_grad()
-def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
+def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, control_strength=1, **sampling_kwargs):
     # print("cond", torch.any(torch.isnan(cond)))
     if condition is not None:
         with torch.no_grad():
@@ -207,7 +207,7 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
     # create an empty tensor of the expected final shape and fill in the current tokens
     seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
     input_pos = torch.arange(0, T, device=device)
-    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, **sampling_kwargs)
+    next_token = prefill(model, cond_combined, input_pos, cfg_scale, condition_combined, control_strength,**sampling_kwargs)
     seq[:, T:T+1] = next_token
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
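In short, the commit threads a new control_strength knob through generate() -> prefill() -> model(...) while leaving the classifier-free-guidance arithmetic in prefill() untouched. As a reading aid, here is a minimal, self-contained sketch of that guidance step; the helper name cfg_blend and the toy tensor sizes are illustrative assumptions, not part of the commit:

import torch

def cfg_blend(logits: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Batch layout assumed by prefill(): the first half of the batch holds
    # conditional logits, the second half unconditional ones (the caller
    # duplicates the inputs upstream to run both branches in one forward).
    cond_logits, uncond_logits = torch.split(logits, len(logits) // 2, dim=0)
    # Classifier-free guidance: push the prediction away from the
    # unconditional branch toward the conditional one, scaled by cfg_scale.
    return uncond_logits + (cond_logits - uncond_logits) * cfg_scale

# Toy check: batch of 4 rows = 2 conditional + 2 unconditional.
logits = torch.randn(4, 16)
print(cfg_blend(logits, cfg_scale=4.0).shape)  # torch.Size([2, 16])

Note that the updated call site passes control_strength positionally, which matches its slot in the new prefill signature (after condition, before **sampling_kwargs), and the default of control_strength=1 presumably preserves the previous behavior for callers that never pass it.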
|