Spaces:
Running
on
Zero
Running
on
Zero
wondervictor
commited on
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
@@ -90,9 +90,6 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
|
|
90 |
|
91 |
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
|
92 |
if cfg_scale > 1.0:
|
93 |
-
print(cond_idx)
|
94 |
-
print(input_pos)
|
95 |
-
print(condition)
|
96 |
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
97 |
logits_combined = logits
|
98 |
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
@@ -142,9 +139,11 @@ def decode_n_tokens(
|
|
142 |
|
143 |
@torch.no_grad()
|
144 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
|
|
145 |
if condition is not None:
|
146 |
condition = model.adapter(condition)
|
147 |
condition = model.adapter_mlp(condition)
|
|
|
148 |
if model.model_type == 'c2i':
|
149 |
if cfg_scale > 1.0:
|
150 |
cond_null = torch.ones_like(cond) * model.num_classes
|
|
|
90 |
|
91 |
def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
|
92 |
if cfg_scale > 1.0:
|
|
|
|
|
|
|
93 |
logits, _ = model(None, cond_idx, input_pos, condition=condition)
|
94 |
logits_combined = logits
|
95 |
cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
|
|
|
139 |
|
140 |
@torch.no_grad()
|
141 |
def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
|
142 |
+
print(condition)
|
143 |
if condition is not None:
|
144 |
condition = model.adapter(condition)
|
145 |
condition = model.adapter_mlp(condition)
|
146 |
+
print(condition)
|
147 |
if model.model_type == 'c2i':
|
148 |
if cfg_scale > 1.0:
|
149 |
cond_null = torch.ones_like(cond) * model.num_classes
|