tastelikefeet committed
Commit 95878c0
1 Parent(s): ddfd056
app.py CHANGED
@@ -13,12 +13,41 @@ import re
 from gradio.components import Component
 from util import check_channels, resize_image, save_images
 import json
+import argparse
+
 
 BBOX_MAX_NUM = 8
 img_save_folder = 'SaveImages'
 load_model = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--use_fp32",
+        action="store_true",
+        default=False,
+        help="Whether or not to use fp32 during inference."
+    )
+    parser.add_argument(
+        "--no_translator",
+        action="store_true",
+        default=False,
+        help="Whether or not to use the CH->EN translator, which enables Chinese prompt input and costs ~4GB of VRAM."
+    )
+    parser.add_argument(
+        "--font_path",
+        type=str,
+        default='font/Arial_Unicode.ttf',
+        help="Path of a font file."
+    )
+    args = parser.parse_args()
+    return args
+
+
+args = parse_args()
 if load_model:
-    inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.0')
+    inference = pipeline('my-anytext-task', model='damo/cv_anytext_text_generation_editing', model_revision='v1.1.1', use_fp16=not args.use_fp32, use_translator=not args.no_translator, font_path=args.font_path)
 
 
 def count_lines(prompt):
@@ -221,7 +250,8 @@ with block:
         [<a href="https://arxiv.org/abs/2311.03054" style="color:blue; font-size:18px;">arXiv</a>] \
         [<a href="https://github.com/tyxsspa/AnyText" style="color:blue; font-size:18px;">Code</a>] \
         [<a href="https://modelscope.cn/models/damo/cv_anytext_text_generation_editing/summary" style="color:blue; font-size:18px;">ModelScope</a>]\
-        version: 1.1.0 </div>')
+        [<a href="https://huggingface.co/spaces/modelscope/AnyText" style="color:blue; font-size:18px;">HuggingFace</a>]\
+        version: 1.1.1 </div>')
     with gr.Row(variant='compact'):
         with gr.Column():
             with gr.Accordion('🕹Instructions(说明)', open=False,):
@@ -305,7 +335,7 @@ with block:
                 rect_xywh_list.extend([x, y, w, h])
 
         rect_img = gr.Image(value=create_canvas(), label="Rext Position(方框位置)", elem_id="MD-bbox-rect-t2i", show_label=False, visible=False)
-        draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=60)
+        draw_img = gr.Image(value=create_canvas(), label="Draw Position(绘制位置)", visible=True, tool='sketch', show_label=False, brush_radius=100)
 
         def re_draw():
             return [gr.Image(value=create_canvas(), tool='sketch'), gr.Slider(value=512), gr.Slider(value=512)]
@@ -357,7 +387,7 @@
             ori_img = gr.Image(label='Ori(原图)')
 
             def upload_ref(x):
-                return [gr.Image(type="numpy", brush_radius=60, tool='sketch'),
+                return [gr.Image(type="numpy", brush_radius=100, tool='sketch'),
                         gr.Image(value=x)]
 
             def clear_ref(x):
@@ -394,8 +424,8 @@
    run_edit.click(fn=process, inputs=[gr.State('edit')] + ips, outputs=[result_gallery, result_info])
 
 block.launch(
-    #server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
-    #share=False,
+    # server_name='0.0.0.0' if os.getenv('GRADIO_LISTEN', '') != '' else "127.0.0.1",
+    # share=False,
     root_path=f"/{os.getenv('GRADIO_PROXY_PATH')}" if os.getenv('GRADIO_PROXY_PATH') else ""
 )
 # block.launch(server_name='0.0.0.0')
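Note on the new flags: the CLI exposes opt-out switches while the pipeline keyword arguments are opt-in, so each flag is negated at the call site (use_fp16=not args.use_fp32). A minimal, self-contained sketch of that inversion pattern; the final kwargs dict only illustrates the wiring and is not part of the commit:

import argparse

parser = argparse.ArgumentParser()
# store_true flags default to False, so fp16 and the translator stay enabled
# unless the user explicitly opts out on the command line.
parser.add_argument("--use_fp32", action="store_true")
parser.add_argument("--no_translator", action="store_true")
args = parser.parse_args()

# Negation turns the opt-out CLI flags into the opt-in booleans the
# pipeline call expects.
kwargs = dict(use_fp16=not args.use_fp32, use_translator=not args.no_translator)
print(kwargs)  # with no flags passed: {'use_fp16': True, 'use_translator': True}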
cldm/cldm.py CHANGED
@@ -32,6 +32,8 @@ class ControlledUnetModel(UNetModel):
         hs = []
         with torch.no_grad():
             t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+            if self.use_fp16:
+                t_emb = t_emb.half()
             emb = self.time_embed(t_emb)
             h = x.type(self.dtype)
             for module in self.input_blocks:
@@ -124,12 +126,12 @@ class ControlNet(nn.Module):
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set.")
-
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
         self.conv_resample = conv_resample
         self.use_checkpoint = use_checkpoint
+        self.use_fp16 = use_fp16
         self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -313,6 +315,8 @@ class ControlNet(nn.Module):
 
     def forward(self, x, hint, text_info, timesteps, context, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
+        if self.use_fp16:
+            t_emb = t_emb.half()
         emb = self.time_embed(t_emb)
 
         # guided_hint from text_info
@@ -344,6 +348,7 @@
 class ControlLDM(LatentDiffusion):
 
     def __init__(self, control_stage_config, control_key, glyph_key, position_key, only_mid_control, loss_alpha=0, loss_beta=0, with_step_weight=False, use_vae_upsample=False, latin_weight=1.0, embedding_manager_config=None, *args, **kwargs):
+        self.use_fp16 = kwargs.pop('use_fp16', False)
         super().__init__(*args, **kwargs)
         self.control_model = instantiate_from_config(control_stage_config)
         self.control_key = control_key
@@ -356,6 +361,7 @@ class ControlLDM(LatentDiffusion):
         self.with_step_weight = with_step_weight
         self.use_vae_upsample = use_vae_upsample
         self.latin_weight = latin_weight
+
         if embedding_manager_config is not None and embedding_manager_config.params.valid:
             self.embedding_manager = self.instantiate_embedding_manager(embedding_manager_config, self.cond_stage_model)
             for param in self.embedding_manager.embedding_parameters():
@@ -369,6 +375,7 @@
             args.rec_image_shape = "3, 48, 320"
             args.rec_batch_num = 6
             args.rec_char_dict_path = './ocr_recog/ppocr_keys_v1.txt'
+            args.use_fp16 = self.use_fp16
             self.cn_recognizer = TextRecognizer(args, self.text_predictor)
             for param in self.text_predictor.parameters():
                 param.requires_grad = False
@@ -433,6 +440,8 @@
         diffusion_model = self.model.diffusion_model
         _cond = torch.cat(cond['c_crossattn'], 1)
         _hint = torch.cat(cond['c_concat'], 1)
+        if self.use_fp16:
+            x_noisy = x_noisy.half()
         control = self.control_model(x=x_noisy, timesteps=t, context=_cond, hint=_hint, text_info=cond['text_info'])
         control = [c * scale for c, scale in zip(control, self.control_scales)]
         eps = diffusion_model(x=x_noisy, timesteps=t, context=_cond, control=control, only_mid_control=self.only_mid_control)
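The recurring pattern in this file is casting tensors that are always created in fp32 (the sinusoidal timestep embedding, the noisy latent) to half precision before they reach fp16-converted submodules. A small sketch of the dtype mismatch being avoided; timestep_embedding_stub is a simplified stand-in for ldm's timestep_embedding, not the real function:

import torch

def timestep_embedding_stub(timesteps, dim):
    # Like ldm's timestep_embedding, the sinusoid is always built in fp32.
    half = dim // 2
    freqs = torch.exp(-torch.arange(half, dtype=torch.float32) / half)
    angles = timesteps[:, None].float() * freqs[None]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

t_emb = timestep_embedding_stub(torch.tensor([1, 10]), 320)
assert t_emb.dtype == torch.float32   # fresh fp32, regardless of model dtype

# After create_model(..., use_fp16=True) the time_embed MLP holds fp16
# weights, so the embedding must be cast before entering it:
use_fp16 = True
if use_fp16:
    t_emb = t_emb.half()
assert t_emb.dtype == torch.float16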
cldm/model.py CHANGED
@@ -21,10 +21,14 @@ def load_state_dict(ckpt_path, location='cpu'):
     return state_dict
 
 
-def create_model(config_path, cond_stage_path=None):
+def create_model(config_path, cond_stage_path=None, use_fp16=False):
     config = OmegaConf.load(config_path)
     if cond_stage_path:
         config.model.params.cond_stage_config.params.version = cond_stage_path  # use pre-downloaded ckpts, in case blocked
+    if use_fp16:
+        config.model.params.use_fp16 = True
+        config.model.params.control_stage_config.params.use_fp16 = True
+        config.model.params.unet_config.params.use_fp16 = True
     model = instantiate_from_config(config.model).cpu()
     print(f'Loaded model config from [{config_path}]')
     return model
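create_model now fans a single use_fp16 switch out to the top-level model and both submodule configs before instantiation. A sketch of the OmegaConf mechanics with a made-up minimal config (the real anytext config has many more keys):

from omegaconf import OmegaConf

config = OmegaConf.create({'model': {'params': {
    'control_stage_config': {'params': {}},
    'unet_config': {'params': {}},
}}})

use_fp16 = True
if use_fp16:
    # plain attribute assignment adds the key to each nested DictConfig
    config.model.params.use_fp16 = True
    config.model.params.control_stage_config.params.use_fp16 = True
    config.model.params.unet_config.params.use_fp16 = True

print(OmegaConf.to_yaml(config.model.params))  # all three show use_fp16: true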
cldm/recognizer.py CHANGED
@@ -132,6 +132,7 @@ class TextRecognizer(object):
         self.chars = self.get_char_dict(args.rec_char_dict_path)
         self.char2id = {x: i for i, x in enumerate(self.chars)}
         self.is_onnx = not isinstance(self.predictor, torch.nn.Module)
+        self.use_fp16 = args.use_fp16
 
     # img: CHW
     def resize_norm_img(self, img, max_wh_ratio):
@@ -188,6 +189,8 @@
             # max_wh_ratio = max(max_wh_ratio, wh_ratio)  # comment to not use different ratio
             for ino in range(beg_img_no, end_img_no):
                 norm_img = self.resize_norm_img(img_list[indices[ino]], max_wh_ratio)
+                if self.use_fp16:
+                    norm_img = norm_img.half()
                 norm_img = norm_img.unsqueeze(0)
                 norm_img_batch.append(norm_img)
             norm_img_batch = torch.cat(norm_img_batch, dim=0)
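Each OCR crop is cast before batching; otherwise torch.cat's type promotion would silently return the whole batch to fp32, and it would no longer match an fp16 predictor. A toy version of the loop with random tensors standing in for the resized crops:

import torch

use_fp16 = True
crops = [torch.rand(3, 48, 320) for _ in range(4)]   # resize_norm_img output is fp32

norm_img_batch = []
for img in crops:
    if use_fp16:
        img = img.half()               # cast per crop, as in the diff
    norm_img_batch.append(img.unsqueeze(0))
norm_img_batch = torch.cat(norm_img_batch, dim=0)

assert norm_img_batch.dtype == torch.float16
assert norm_img_batch.shape == (4, 3, 48, 320)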
ldm/modules/diffusionmodules/openaimodel.py CHANGED
@@ -510,7 +510,7 @@ class UNetModel(nn.Module):
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set.")
-
+        self.use_fp16 = use_fp16
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
         self.channel_mult = channel_mult
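UNetModel already derives self.dtype from use_fp16; storing the raw flag as well lets subclasses such as ControlledUnetModel branch on it inside forward. A compressed illustration of the pattern with a trivial module (hypothetical, not the real UNet):

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    def __init__(self, use_fp16=False):
        super().__init__()
        self.use_fp16 = use_fp16   # keep the flag itself for forward-time branches...
        self.dtype = torch.float16 if use_fp16 else torch.float32  # ...and the dtype

    def forward(self, x):
        if self.use_fp16:          # branch on the stored flag, as the diff does
            x = x.half()
        return x.type(self.dtype)

assert TinyUNet(use_fp16=True)(torch.randn(1, 4)).dtype == torch.float16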
ldm/modules/diffusionmodules/util.py CHANGED
@@ -216,7 +216,8 @@ class SiLU(nn.Module):
 
 class GroupNorm32(nn.GroupNorm):
     def forward(self, x):
-        return super().forward(x.float()).type(x.dtype)
+        # return super().forward(x.float()).type(x.dtype)
+        return super().forward(x).type(x.dtype)
 
 def conv_nd(dims, *args, **kwargs):
     """