realantonvoronov commited on
Commit
e5b0112
1 Parent(s): 94cd78d

fix cfg to const

Browse files
Files changed (1) hide show
  1. models/pipeline.py +3 -3
models/pipeline.py CHANGED
@@ -164,10 +164,10 @@ class SwittiPipeline:
164
  if crop_cond is not None:
165
  crop_cond = crop_cond[:B]
166
  for b in switti.blocks:
167
- if b.attn.caching:
168
  b.attn.cached_k = b.attn.cached_k[:B]
169
  b.attn.cached_v = b.attn.cached_v[:B]
170
- if b.cross_attn.caching:
171
  b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
172
  b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
173
 
@@ -187,7 +187,7 @@ class SwittiPipeline:
187
 
188
  # Guidance
189
  if si < turn_off_cfg_start_si:
190
- t = cfg * ratio
191
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
192
 
193
  if more_smooth and si >= smooth_start_si:
 
164
  if crop_cond is not None:
165
  crop_cond = crop_cond[:B]
166
  for b in switti.blocks:
167
+ if b.attn.caching and b.attn.cached_k is not None:
168
  b.attn.cached_k = b.attn.cached_k[:B]
169
  b.attn.cached_v = b.attn.cached_v[:B]
170
+ if b.cross_attn.caching and b.cross_attn.cached_k is not None:
171
  b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
172
  b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
173
 
 
187
 
188
  # Guidance
189
  if si < turn_off_cfg_start_si:
190
+ t = cfg
191
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
192
 
193
  if more_smooth and si >= smooth_start_si: