realantonvoronov commited on
Commit
1e17711
1 Parent(s): e5b0112

update sampling paramters in pipeline and arguments in app

Browse files
Files changed (2) hide show
  1. app.py +22 -2
  2. models/pipeline.py +13 -3
app.py CHANGED
@@ -27,11 +27,16 @@ def infer(
27
  more_smooth=True,
28
  smooth_start_si=2,
29
  turn_off_cfg_start_si=10,
 
 
30
  progress=gr.Progress(track_tqdm=True),
31
  ):
32
  if randomize_seed:
33
  seed = random.randint(0, MAX_SEED)
34
 
 
 
 
35
  image = pipe(
36
  prompt=prompt,
37
  null_prompt=negative_prompt,
@@ -41,6 +46,7 @@ def infer(
41
  more_smooth=more_smooth,
42
  smooth_start_si=smooth_start_si,
43
  turn_off_cfg_start_si=turn_off_cfg_start_si,
 
44
  seed=seed,
45
  )[0]
46
 
@@ -103,7 +109,7 @@ with gr.Blocks(css=css) as demo:
103
  minimum=0.0,
104
  maximum=10.,
105
  step=0.5,
106
- value=4.,
107
  )
108
 
109
  with gr.Accordion("Advanced Settings", open=False):
@@ -140,12 +146,24 @@ with gr.Blocks(css=css) as demo:
140
  value=2,
141
  )
142
  turn_off_cfg_start_si = gr.Slider(
143
- label="Disable CFG from scale",
144
  minimum=0,
145
  maximum=10,
146
  step=1,
147
  value=8,
148
  )
 
 
 
 
 
 
 
 
 
 
 
 
149
 
150
 
151
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
@@ -163,6 +181,8 @@ with gr.Blocks(css=css) as demo:
163
  more_smooth,
164
  smooth_start_si,
165
  turn_off_cfg_start_si,
 
 
166
  ],
167
  outputs=[result, seed],
168
  )
 
27
  more_smooth=True,
28
  smooth_start_si=2,
29
  turn_off_cfg_start_si=10,
30
+ more_diverse=True,
31
+ last_scale_temp=None,
32
  progress=gr.Progress(track_tqdm=True),
33
  ):
34
  if randomize_seed:
35
  seed = random.randint(0, MAX_SEED)
36
 
37
+
38
+ turn_on_cfg_start_si = 2 if more_diverse else 0
39
+
40
  image = pipe(
41
  prompt=prompt,
42
  null_prompt=negative_prompt,
 
46
  more_smooth=more_smooth,
47
  smooth_start_si=smooth_start_si,
48
  turn_off_cfg_start_si=turn_off_cfg_start_si,
49
+ turn_on_cfg_start_si=turn_on_cfg_start_si,
50
  seed=seed,
51
  )[0]
52
 
 
109
  minimum=0.0,
110
  maximum=10.,
111
  step=0.5,
112
+ value=6.,
113
  )
114
 
115
  with gr.Accordion("Advanced Settings", open=False):
 
146
  value=2,
147
  )
148
  turn_off_cfg_start_si = gr.Slider(
149
+ label="Disable CFG starting scale",
150
  minimum=0,
151
  maximum=10,
152
  step=1,
153
  value=8,
154
  )
155
+ with gr.Row():
156
+ more_diverse = gr.Checkbox(label="More diverse", value=True)
157
+ apply_late_temperature = gr.Checkbox(label="Temperature after disabling CFG", value=False)
158
+ last_scale_temp = gr.Slider(
159
+ label="Late temperature value",
160
+ minimum=0.1,
161
+ maximum=10,
162
+ step=0.1,
163
+ value=1,
164
+ )
165
+ if not apply_late_temperature:
166
+ last_scale_temp = None
167
 
168
 
169
  gr.Examples(examples=examples, inputs=[prompt], outputs=[result, seed], fn=infer, cache_examples=True)# cache_mode="lazy")
 
181
  more_smooth,
182
  smooth_start_si,
183
  turn_off_cfg_start_si,
184
+ more_diverse,
185
+ last_scale_temp,
186
  ],
187
  outputs=[result, seed],
188
  )
models/pipeline.py CHANGED
@@ -91,7 +91,9 @@ class SwittiPipeline:
91
  return_pil: bool = True,
92
  smooth_start_si: int = 0,
93
  turn_off_cfg_start_si: int = 10,
 
94
  image_size: tuple[int, int] = (512, 512),
 
95
  ) -> torch.Tensor | list[PILImage]:
96
  """
97
  only used for inference, on autoregressive mode
@@ -155,7 +157,8 @@ class SwittiPipeline:
155
  else:
156
  freqs_cis = switti.freqs_cis
157
 
158
- if si >= turn_off_cfg_start_si:
 
159
  x_BLC = x_BLC[:B]
160
  context = context[:B]
161
  context_attn_bias = context_attn_bias[:B]
@@ -170,6 +173,8 @@ class SwittiPipeline:
170
  if b.cross_attn.caching and b.cross_attn.cached_k is not None:
171
  b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
172
  b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
 
 
173
 
174
  for block in switti.blocks:
175
  x_BLC = block(
@@ -186,11 +191,16 @@ class SwittiPipeline:
186
  logits_BlV = switti.get_logits(x_BLC, cond_BD)
187
 
188
  # Guidance
189
- if si < turn_off_cfg_start_si:
 
 
 
190
  t = cfg
191
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
 
 
192
 
193
- if more_smooth and si >= smooth_start_si:
194
  # not used when evaluating FID/IS/Precision/Recall
195
  gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
196
  idx_Bl = gumbel_softmax_with_rng(
 
91
  return_pil: bool = True,
92
  smooth_start_si: int = 0,
93
  turn_off_cfg_start_si: int = 10,
94
+ turn_on_cfg_start_si: int = 0,
95
  image_size: tuple[int, int] = (512, 512),
96
+ last_scale_temp: None | float = None,
97
  ) -> torch.Tensor | list[PILImage]:
98
  """
99
  only used for inference, on autoregressive mode
 
157
  else:
158
  freqs_cis = switti.freqs_cis
159
 
160
+ if si < turn_on_cfg_start_si or si >= turn_off_cfg_start_si:
161
+ apply_smooth = False
162
  x_BLC = x_BLC[:B]
163
  context = context[:B]
164
  context_attn_bias = context_attn_bias[:B]
 
173
  if b.cross_attn.caching and b.cross_attn.cached_k is not None:
174
  b.cross_attn.cached_k = b.cross_attn.cached_k[:B]
175
  b.cross_attn.cached_v = b.cross_attn.cached_v[:B]
176
+ else:
177
+ apply_smooth = more_smooth
178
 
179
  for block in switti.blocks:
180
  x_BLC = block(
 
191
  logits_BlV = switti.get_logits(x_BLC, cond_BD)
192
 
193
  # Guidance
194
+ if si < turn_on_cfg_start_si:
195
+ t = 0 # no guidance
196
+ elif si >= turn_on_cfg_start_si and si < turn_off_cfg_start_si:
197
+ # default const cfg
198
  t = cfg
199
  logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
200
+ elif last_scale_temp is not None:
201
+ logits_BlV = logits_BlV / last_scale_temp
202
 
203
+ if apply_smooth and si >= smooth_start_si:
204
  # not used when evaluating FID/IS/Precision/Recall
205
  gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
206
  idx_Bl = gumbel_softmax_with_rng(