3v324v23 committed
Commit: d5cd13d
Parent: 3c9b64a
.gitignore CHANGED
@@ -7,3 +7,5 @@ log/
 log
 pretrained/
 pretrained
+gradio_cached_examples/
+gradio_cached_examples
app.py CHANGED
@@ -252,10 +252,6 @@ class vd_inference(object):
             assert False, 'Model type not supported'
         net = get_model()(cfgm)

-        if self.which == 'v1.0':
-            sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
-            net.load_state_dict(sd, strict=False)
-
         if fp16:
             highlight_print('Running in FP16')
             if self.which == 'v1.0':
@@ -266,6 +262,20 @@ class vd_inference(object):
         else:
             self.dtype = torch.float32

+        if self.which == 'v1.0':
+            # if fp16:
+            #     sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
+            # else:
+            #     sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
+            from huggingface_hub import hf_hub_download
+            if fp16:
+                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')
+            else:
+                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth')
+            sd = torch.load(temppath, map_location='cpu')
+
+            net.load_state_dict(sd, strict=False)
+
         self.use_cuda = torch.cuda.is_available()
         if self.use_cuda:
             net.to('cuda')
@@ -855,9 +865,11 @@ def tcg_interface(with_example=False):
            cache_examples=cache_examples, )

    gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
-           '<img src="file/assets/demo/misc/mask_inst1.gif" style="float:left;max-width:450px;">'+
-           '<img src="file/assets/demo/misc/mask_inst2.gif" style="float:left;max-width:450px;">'+
-           '<img src="file/assets/demo/misc/mask_inst3.gif" style="float:left;max-width:450px;">',)
+           '<div id="maskinst">'+
+           '<img src="file/assets/demo/misc/mask_inst1.gif">'+
+           '<img src="file/assets/demo/misc/mask_inst2.gif">'+
+           '<img src="file/assets/demo/misc/mask_inst3.gif">'+
+           '</div>')

 def mcg_interface(with_example=False):
     num_img_input = 4
@@ -917,9 +929,11 @@ def mcg_interface(with_example=False):
            cache_examples=cache_examples, )

    gr.HTML('<br><p id=myinst>&nbsp How to add mask: Please see the following instructions.</p><br>'+
-           '<img src="file/assets/demo/misc/mask_inst1.gif" style="float:left;max-width:450px;">'+
-           '<img src="file/assets/demo/misc/mask_inst2.gif" style="float:left;max-width:450px;">'+
-           '<img src="file/assets/demo/misc/mask_inst3.gif" style="float:left;max-width:450px;">',)
+           '<div id="maskinst">'+
+           '<img src="file/assets/demo/misc/mask_inst1.gif">'+
+           '<img src="file/assets/demo/misc/mask_inst2.gif">'+
+           '<img src="file/assets/demo/misc/mask_inst3.gif">'+
+           '</div>')

 ###########
 # Example #
@@ -1017,6 +1031,21 @@ css = """
     margin: 0rem;
     color: #6B7280;
 }
+#maskinst {
+    text-align: justify;
+    min-width: 1200px;
+}
+#maskinst>img {
+    min-width:399px;
+    max-width:450px;
+    vertical-align: top;
+    display: inline-block;
+}
+#maskinst:after {
+    content: "";
+    width: 100%;
+    display: inline-block;
+}
 """

 if True:
@@ -1025,7 +1054,7 @@ if True:
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
-       Versatile Diffusion{}
+       Versatile Diffusion
        </h1>
        <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
        We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
@@ -1041,8 +1070,7 @@
        [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
        </h3>
        </div>
-       """.format(' '+vd_inference.which))
-       # .format('')) #
+       """)

     with gr.Tab('Text-to-Image'):
         t2i_interface(with_example=True)
@@ -1061,7 +1089,10 @@

     gr.HTML(
        """
-       <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
+       <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
+       <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+       <b>Version</b>: {}
+       </h3>
        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
        <b>Caution</b>:
        We would like the raise the awareness of users of this demo of its potential issues and concerns.
@@ -1077,7 +1108,7 @@
        VD in this demo is meant only for research purposes.
        </h3>
        </div>
-       """)
+       """.format(' '+vd_inference.which))

     demo.launch(share=True)
     # demo.launch(debug=True)
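
The app.py change above swaps a checkpoint bundled under pretrained/ for an on-demand fetch from the shi-labs/versatile-diffusion-model Hub repo, and moves the load after the FP16 switch so the matching fp16 checkpoint can be selected. A minimal standalone sketch of that loading path, assuming only the repo and filenames shown in the diff (the helper name is illustrative, not part of the commit):

# Sketch of the new weight-loading path: fetch the checkpoint from the
# Hugging Face Hub instead of a file shipped inside the Space.
import torch
from huggingface_hub import hf_hub_download

def load_vd_checkpoint(fp16=True):  # hypothetical helper mirroring the diff
    # Pick the checkpoint variant that matches the runtime precision.
    filename = ('pretrained_pth/vd-four-flow-v1-0-fp16.pth' if fp16
                else 'pretrained_pth/vd-four-flow-v1-0.pth')
    # hf_hub_download returns a path inside the local HF cache; repeated
    # calls reuse the cached file rather than re-downloading it.
    temppath = hf_hub_download('shi-labs/versatile-diffusion-model', filename)
    return torch.load(temppath, map_location='cpu')

# Usage, as in the diff:
# sd = load_vd_checkpoint(fp16=True)
# net.load_state_dict(sd, strict=False)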
configs/model/optimus.yaml CHANGED
@@ -92,7 +92,6 @@ optimus_gpt2_tokenizer:
 optimus_v1:
   super_cfg: optimus
   type: optimus_vae_next
-  pth: pretrained/optimus-vae.pth
   args:
     encoder: MODEL(optimus_bert_encoder)
     decoder: MODEL(optimus_gpt2_decoder)
@@ -100,3 +99,5 @@ optimus_v1:
     tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
     args:
       latent_size: 768
+  # pth: pretrained/optimus-vae.pth
+  hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth']
lib/model_zoo/common/get_model.py CHANGED
@@ -8,27 +8,6 @@ from .utils import \
     get_total_param, get_total_param_sum, \
     get_unit

-# def load_state_dict(net, model_path):
-#     if isinstance(net, dict):
-#         for ni, neti in net.items():
-#             paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
-#             new_paras = neti.state_dict()
-#             new_paras.update(paras)
-#             neti.load_state_dict(new_paras)
-#     else:
-#         paras = torch.load(model_path, map_location=torch.device('cpu'))
-#         new_paras = net.state_dict()
-#         new_paras.update(paras)
-#         net.load_state_dict(new_paras)
-#     return
-
-# def save_state_dict(net, path):
-#     if isinstance(net, (torch.nn.DataParallel,
-#                         torch.nn.parallel.DistributedDataParallel)):
-#         torch.save(net.module.state_dict(), path)
-#     else:
-#         torch.save(net.state_dict(), path)
-
 def singleton(class_):
     instances = {}
     def getinstance(*args, **kwargs):
@@ -94,6 +73,14 @@ class get_model(object):
             net.load_state_dict(sd, strict=strict_sd)
             if verbose:
                 print_log('Load pth from {}'.format(cfg.pth))
+        elif 'hfm' in cfg:
+            from huggingface_hub import hf_hub_download
+            temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
+            sd = torch.load(temppath, map_location='cpu')
+            strict_sd = cfg.get('strict_sd', True)
+            net.load_state_dict(sd, strict=strict_sd)
+            if verbose:
+                print_log('Load hfm from {}/{}'.format(*cfg.hfm))

         # display param_num & param_sum
         if verbose:
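
Read together with the configs/model/optimus.yaml change, this gives the config schema a second checkpoint source: alongside a local `pth` path, an `hfm` entry is a `[repo_id, filename_in_repo]` pair resolved through the Hub. A minimal sketch of that resolution, assuming omegaconf (already pinned in requirements.txt); the real model is omitted, so the load call stays commented:

# Sketch of resolving an 'hfm' entry like the one added to optimus.yaml:
# a two-element list [repo_id, filename_in_repo] becomes a locally cached
# file path, then loads the same way a local 'pth' would.
import torch
from omegaconf import OmegaConf
from huggingface_hub import hf_hub_download

# Hypothetical config fragment shaped like the new optimus.yaml entry.
cfg = OmegaConf.create({
    'hfm': ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth'],
})

if 'hfm' in cfg:
    temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
    sd = torch.load(temppath, map_location='cpu')
    strict_sd = cfg.get('strict_sd', True)  # same default as the diff
    # net.load_state_dict(sd, strict=strict_sd)
    print('state dict cached at', temppath, 'with', len(sd), 'entries')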
requirements.txt CHANGED
@@ -12,5 +12,5 @@ torchmetrics==0.7.3

 einops==0.3.0
 omegaconf==2.1.1
-huggingface-hub==0.10.1
+huggingface-hub==0.11.1
 gradio==3.17.1