Commit: new app
Files changed:
- .gitignore (+2, -0)
- app.py (+46, -15)
- configs/model/optimus.yaml (+2, -1)
- lib/model_zoo/common/get_model.py (+8, -21)
- requirements.txt (+1, -1)
.gitignore CHANGED

```diff
@@ -7,3 +7,5 @@ log/
 log
 pretrained/
 pretrained
+gradio_cached_examples/
+gradio_cached_examples
```
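For context on the new ignore entries: when a Gradio interface is built with example caching enabled, Gradio 3.x precomputes the example outputs and writes them to a `gradio_cached_examples/` directory in the working tree, which is presumably why the commit stops tracking it. A minimal sketch of the setting that produces this directory (the `echo` app is a made-up placeholder, not part of this Space):

```python
import gradio as gr

def echo(text):
    # Placeholder function; only the caching behavior matters here.
    return text

# With cache_examples=True, Gradio precomputes outputs for the listed
# examples and stores them under ./gradio_cached_examples/ -- the
# directory the updated .gitignore now excludes.
demo = gr.Interface(fn=echo, inputs='text', outputs='text',
                    examples=[['hello'], ['world']],
                    cache_examples=True)

if __name__ == '__main__':
    demo.launch()
```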
app.py CHANGED

```diff
@@ -252,10 +252,6 @@ class vd_inference(object):
             assert False, 'Model type not supported'
         net = get_model()(cfgm)
 
-        if self.which == 'v1.0':
-            sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
-            net.load_state_dict(sd, strict=False)
-
         if fp16:
             highlight_print('Running in FP16')
             if self.which == 'v1.0':
@@ -266,6 +262,20 @@ class vd_inference(object):
         else:
             self.dtype = torch.float32
 
+        if self.which == 'v1.0':
+            # if fp16:
+            #     sd = torch.load('pretrained/vd-four-flow-v1-0-fp16.pth', map_location='cpu')
+            # else:
+            #     sd = torch.load('pretrained/vd-four-flow-v1-0.pth', map_location='cpu')
+            from huggingface_hub import hf_hub_download
+            if fp16:
+                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')
+            else:
+                temppath = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0.pth')
+            sd = torch.load(temppath, map_location='cpu')
+
+        net.load_state_dict(sd, strict=False)
+
         self.use_cuda = torch.cuda.is_available()
         if self.use_cuda:
             net.to('cuda')
@@ -855,9 +865,11 @@ def tcg_interface(with_example=False):
             cache_examples=cache_examples, )
 
     gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
-            '<
-
-
+            '<div id="maskinst">'+
+            '<img src="file/assets/demo/misc/mask_inst1.gif">'+
+            '<img src="file/assets/demo/misc/mask_inst2.gif">'+
+            '<img src="file/assets/demo/misc/mask_inst3.gif">'+
+            '</div>')
 
 def mcg_interface(with_example=False):
     num_img_input = 4
@@ -917,9 +929,11 @@ def mcg_interface(with_example=False):
             cache_examples=cache_examples, )
 
     gr.HTML('<br><p id=myinst>  How to add mask: Please see the following instructions.</p><br>'+
-            '<
-
-
+            '<div id="maskinst">'+
+            '<img src="file/assets/demo/misc/mask_inst1.gif">'+
+            '<img src="file/assets/demo/misc/mask_inst2.gif">'+
+            '<img src="file/assets/demo/misc/mask_inst3.gif">'+
+            '</div>')
 
 ###########
 # Example #
@@ -1017,6 +1031,21 @@ css = """
     margin: 0rem;
     color: #6B7280;
 }
+#maskinst {
+    text-align: justify;
+    min-width: 1200px;
+}
+#maskinst>img {
+    min-width:399px;
+    max-width:450px;
+    vertical-align: top;
+    display: inline-block;
+}
+#maskinst:after {
+    content: "";
+    width: 100%;
+    display: inline-block;
+}
 """
 
 if True:
@@ -1025,7 +1054,7 @@ if True:
         """
         <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
         <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
-        Versatile Diffusion
+        Versatile Diffusion
         </h1>
         <h2 style="font-weight: 450; font-size: 1rem; margin-top: 0.5rem; margin-bottom: 0.5rem">
         We built <b>Versatile Diffusion (VD), the first unified multi-flow multimodal diffusion framework</b>, as a step towards <b>Universal Generative AI</b>.
@@ -1041,8 +1070,7 @@ if True:
         [<a href="https://github.com/SHI-Labs/Versatile-Diffusion" style="color:blue;">GitHub</a>]
         </h3>
         </div>
-        """
-        # .format('')) #
+        """)
 
     with gr.Tab('Text-to-Image'):
         t2i_interface(with_example=True)
@@ -1061,7 +1089,10 @@ if True:
 
     gr.HTML(
         """
-        <div style="text-align:
+        <div style="text-align: justify; max-width: 1200px; margin: 20px auto;">
+        <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
+        <b>Version</b>: {}
+        </h3>
         <h3 style="font-weight: 450; font-size: 0.8rem; margin: 0rem">
         <b>Caution</b>:
         We would like the raise the awareness of users of this demo of its potential issues and concerns.
@@ -1077,7 +1108,7 @@ if True:
         VD in this demo is meant only for research purposes.
         </h3>
         </div>
-        """)
+        """.format(' '+vd_inference.which))
 
 demo.launch(share=True)
 # demo.launch(debug=True)
```
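The substantive change above is the weight-loading path: instead of reading `vd-four-flow-v1-0.pth` from a `pretrained/` folder checked into the Space, the app now fetches it from the `shi-labs/versatile-diffusion-model` repo on the Hub. A minimal sketch of that pattern in isolation, assuming only torch and huggingface_hub are installed; the helper name `load_vd_weights` is ours, not part of the diff:

```python
import torch
from huggingface_hub import hf_hub_download

def load_vd_weights(net, fp16=False):
    # Same repo id and file names as in the diff above.
    filename = ('pretrained_pth/vd-four-flow-v1-0-fp16.pth' if fp16
                else 'pretrained_pth/vd-four-flow-v1-0.pth')
    # Downloads on the first call, then resolves from the local HF cache,
    # so restarting the app does not re-download the checkpoint.
    path = hf_hub_download('shi-labs/versatile-diffusion-model', filename)
    sd = torch.load(path, map_location='cpu')
    # strict=False, as in the diff: tolerate keys present on only one side.
    net.load_state_dict(sd, strict=False)
    return net
```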
configs/model/optimus.yaml CHANGED

```diff
@@ -92,7 +92,6 @@ optimus_gpt2_tokenizer:
 optimus_v1:
   super_cfg: optimus
   type: optimus_vae_next
-  pth: pretrained/optimus-vae.pth
   args:
     encoder: MODEL(optimus_bert_encoder)
     decoder: MODEL(optimus_gpt2_decoder)
@@ -100,3 +99,5 @@ optimus_v1:
     tokenizer_decoder: MODEL(optimus_gpt2_tokenizer)
     args:
       latent_size: 768
+  # pth: pretrained/optimus-vae.pth
+  hfm: ['shi-labs/versatile-diffusion-model', 'pretrained_pth/optimus-vae.pth']
```
lib/model_zoo/common/get_model.py CHANGED

```diff
@@ -8,27 +8,6 @@ from .utils import \
     get_total_param, get_total_param_sum, \
     get_unit
 
-# def load_state_dict(net, model_path):
-#     if isinstance(net, dict):
-#         for ni, neti in net.items():
-#             paras = torch.load(model_path[ni], map_location=torch.device('cpu'))
-#             new_paras = neti.state_dict()
-#             new_paras.update(paras)
-#             neti.load_state_dict(new_paras)
-#     else:
-#         paras = torch.load(model_path, map_location=torch.device('cpu'))
-#         new_paras = net.state_dict()
-#         new_paras.update(paras)
-#         net.load_state_dict(new_paras)
-#     return
-
-# def save_state_dict(net, path):
-#     if isinstance(net, (torch.nn.DataParallel,
-#                         torch.nn.parallel.DistributedDataParallel)):
-#         torch.save(net.module.state_dict(), path)
-#     else:
-#         torch.save(net.state_dict(), path)
-
 def singleton(class_):
     instances = {}
     def getinstance(*args, **kwargs):
@@ -94,6 +73,14 @@ class get_model(object):
             net.load_state_dict(sd, strict=strict_sd)
             if verbose:
                 print_log('Load pth from {}'.format(cfg.pth))
+        elif 'hfm' in cfg:
+            from huggingface_hub import hf_hub_download
+            temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
+            sd = torch.load(temppath, map_location='cpu')
+            strict_sd = cfg.get('strict_sd', True)
+            net.load_state_dict(sd, strict=strict_sd)
+            if verbose:
+                print_log('Load hfm from {}/{}'.format(*cfg.hfm))
 
         # display param_num & param_sum
         if verbose:
```
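Taken together with the yaml change, the new `elif` branch gives every model config two interchangeable ways to point at weights. A hedged sketch of the combined logic, assuming an OmegaConf `cfg` and any `torch.nn.Module`; the free function `load_weights` is ours, standing in for the method body in get_model.py:

```python
import torch
from huggingface_hub import hf_hub_download

def load_weights(net, cfg, verbose=True):
    # Mirrors the two branches in get_model.py: 'pth' is a local file path,
    # 'hfm' is a [repo_id, filename] pair resolved through the Hub cache.
    if 'pth' in cfg:
        temppath = cfg.pth
    elif 'hfm' in cfg:
        temppath = hf_hub_download(cfg.hfm[0], cfg.hfm[1])
    else:
        return net  # nothing to load
    sd = torch.load(temppath, map_location='cpu')
    net.load_state_dict(sd, strict=cfg.get('strict_sd', True))
    if verbose:
        print('Loaded weights from {}'.format(temppath))
    return net
```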
requirements.txt CHANGED

```diff
@@ -12,5 +12,5 @@ torchmetrics==0.7.3
 
 einops==0.3.0
 omegaconf==2.1.1
-huggingface-hub==0.
+huggingface-hub==0.11.1
 gradio==3.17.1
```