Fabrice-TIERCELIN committed on
Commit 93e1d77
1 Parent(s): 0879b61

Fix function

Files changed (1)
  1. gradio_demo.py +15 -17
gradio_demo.py CHANGED
@@ -160,7 +160,6 @@ def stage2_process(
     if 1 < downscale:
         input_height, input_width, input_channel = np.array(input_image).shape
         input_image = input_image.resize((input_width // downscale, input_height // downscale), Image.LANCZOS)
-    torch.cuda.set_device(SUPIR_device)
     event_id = str(time.time_ns())
     event_dict = {'event_id': event_id, 'localtime': time.ctime(), 'prompt': prompt, 'a_prompt': a_prompt,
                   'n_prompt': n_prompt, 'num_samples': num_samples, 'upscale': upscale, 'edm_steps': edm_steps,
@@ -181,23 +180,8 @@ def stage2_process(
     input_image = upscale_image(input_image, upscale, unit_resolution=32,
                                 min_size=min_size)
 
-    LQ = np.array(input_image) / 255.0
-    LQ = np.power(LQ, gamma_correction)
-    LQ *= 255.0
-    LQ = LQ.round().clip(0, 255).astype(np.uint8)
-    LQ = LQ / 255 * 2 - 1
-    LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
-    if use_llava:
-        captions = [prompt]
-    else:
-        captions = ['']
-
-    model.ae_dtype = convert_dtype(ae_dtype)
-    model.model.dtype = convert_dtype(diff_dtype)
-
     samples = restore(
         model,
-        LQ,
         captions,
         edm_steps,
         s_stage1,
@@ -255,7 +239,6 @@ def stage2_process(
 @spaces.GPU(duration=600)
 def restore(
     model,
-    LQ,
     captions,
     edm_steps,
     s_stage1,
@@ -273,6 +256,21 @@ def restore(
     spt_linear_CFG,
     spt_linear_s_stage2
 ):
+    torch.cuda.set_device(SUPIR_device)
+    LQ = np.array(input_image) / 255.0
+    LQ = np.power(LQ, gamma_correction)
+    LQ *= 255.0
+    LQ = LQ.round().clip(0, 255).astype(np.uint8)
+    LQ = LQ / 255 * 2 - 1
+    LQ = torch.tensor(LQ, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
+    if use_llava:
+        captions = [prompt]
+    else:
+        captions = ['']
+
+    model.ae_dtype = convert_dtype(ae_dtype)
+    model.model.dtype = convert_dtype(diff_dtype)
+
     return model.batchify_sample(LQ, captions, num_steps=edm_steps, restoration_scale=s_stage1, s_churn=s_churn,
                                  s_noise=s_noise, cfg_scale=s_cfg, control_scale=s_stage2, seed=seed,
                                  num_samples=num_samples, p_p=a_prompt, n_p=n_prompt, color_fix_type=color_fix_type,
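
For reference, a minimal self-contained sketch of the low-quality (LQ) tensor preparation that this commit relocates from stage2_process into restore. The helper name prepare_lq and the usage block are illustrative assumptions; only the gamma correction and the [-1, 1] normalization steps come from the diff above.

import numpy as np
import torch
from PIL import Image

def prepare_lq(img: Image.Image, gamma_correction: float, device: str) -> torch.Tensor:
    # Hypothetical helper mirroring the LQ preparation moved into restore() by this commit.
    lq = np.array(img) / 255.0                     # uint8 -> float in [0, 1]
    lq = np.power(lq, gamma_correction)            # apply gamma correction
    lq *= 255.0
    lq = lq.round().clip(0, 255).astype(np.uint8)  # back to 8-bit after correction
    lq = lq / 255 * 2 - 1                          # rescale to [-1, 1]
    lq = torch.tensor(lq, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)  # HWC -> NCHW
    return lq.to(device)[:, :3, :, :]              # keep only the RGB channels

if __name__ == '__main__':
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    sample = Image.new('RGB', (64, 64), color='gray')
    print(prepare_lq(sample, gamma_correction=1.0, device=device).shape)  # torch.Size([1, 3, 64, 64])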
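
One plausible reading of the move (an assumption; the commit message only says "Fix function"): on Hugging Face ZeroGPU Spaces, CUDA should only be touched inside a function decorated with @spaces.GPU, so the device selection and tensor transfer belong in restore rather than in the Gradio callback. A hedged sketch of that pattern follows; run_restore, the value of SUPIR_device, and the dummy body are stand-ins, while the decorator and duration=600 appear in the diff.

import spaces
import torch

SUPIR_device = 'cuda:0'  # assumption: the demo resolves this at startup

@spaces.GPU(duration=600)
def run_restore(batch_cpu: torch.Tensor) -> torch.Tensor:
    # Safe to touch CUDA here: the GPU is attached for the duration of this call.
    torch.cuda.set_device(SUPIR_device)
    batch = batch_cpu.to(SUPIR_device)
    # ... model.batchify_sample(batch, ...) would run here in the real demo ...
    return batch.cpu()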