wondervictor commited on
Commit
31086e2
·
verified ·
1 Parent(s): 6c21dc4

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +2 -3
autoregressive/models/generate.py CHANGED
@@ -90,9 +90,6 @@ def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: i
90
 
91
  def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
92
  if cfg_scale > 1.0:
93
- print(cond_idx)
94
- print(input_pos)
95
- print(condition)
96
  logits, _ = model(None, cond_idx, input_pos, condition=condition)
97
  logits_combined = logits
98
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
@@ -142,9 +139,11 @@ def decode_n_tokens(
142
 
143
  @torch.no_grad()
144
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
 
145
  if condition is not None:
146
  condition = model.adapter(condition)
147
  condition = model.adapter_mlp(condition)
 
148
  if model.model_type == 'c2i':
149
  if cfg_scale > 1.0:
150
  cond_null = torch.ones_like(cond) * model.num_classes
 
90
 
91
  def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, condition:torch.Tensor, **sampling_kwargs):
92
  if cfg_scale > 1.0:
 
 
 
93
  logits, _ = model(None, cond_idx, input_pos, condition=condition)
94
  logits_combined = logits
95
  cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
 
139
 
140
  @torch.no_grad()
141
  def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, condition=None, condition_null=None, condition_token_nums=0, **sampling_kwargs):
142
+ print(condition)
143
  if condition is not None:
144
  condition = model.adapter(condition)
145
  condition = model.adapter_mlp(condition)
146
+ print(condition)
147
  if model.model_type == 'c2i':
148
  if cfg_scale > 1.0:
149
  cond_null = torch.ones_like(cond) * model.num_classes