wondervictor commited on
Commit
8a67645
·
verified ·
1 Parent(s): 8cdac78

Update autoregressive/models/generate.py

Browse files
Files changed (1) hide show
  1. autoregressive/models/generate.py +6 -4
autoregressive/models/generate.py CHANGED
@@ -140,10 +140,12 @@ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_int
140
  condition = condition.to(torch.float32)
141
  print(condition)
142
  if condition is not None:
143
- print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
144
- condition = model.adapter(condition)
145
- condition = model.adapter_mlp(condition)
146
- print(condition)
 
 
147
  if model.model_type == 'c2i':
148
  if cfg_scale > 1.0:
149
  cond_null = torch.ones_like(cond) * model.num_classes
 
140
  condition = condition.to(torch.float32)
141
  print(condition)
142
  if condition is not None:
143
+ with torch.no_grad():
144
+ print(model.adapter.model.embeddings.patch_embeddings.projection.weight)
145
+ condition = model.adapter(condition)
146
+ print(condition)
147
+ condition = model.adapter_mlp(condition)
148
+ print(condition)
149
  if model.model_type == 'c2i':
150
  if cfg_scale > 1.0:
151
  cond_null = torch.ones_like(cond) * model.num_classes