tubui
commited on
Commit
•
dfec228
0
Parent(s):
Duplicate from tubui/test
Browse files- .gitattributes +34 -0
- Dockerfile +10 -0
- Embed_Secret.py +264 -0
- README.md +12 -0
- cldm/ae.py +727 -0
- cldm/cldm.py +517 -0
- cldm/diffsteg.py +782 -0
- cldm/hack.py +113 -0
- cldm/logger.py +149 -0
- cldm/loss.py +78 -0
- cldm/loss_weight_scheduler.py +17 -0
- cldm/model.py +28 -0
- cldm/plms.py +1481 -0
- cldm/tmp.py +340 -0
- cldm/transformations.py +127 -0
- cldm/transformations2.py +415 -0
- cldm/utils.py +539 -0
- flae/models.py +325 -0
- flae/munit.py +576 -0
- flae/unet.py +123 -0
- ldm/modules/ema.py +80 -0
- ldm/util.py +197 -0
- pages/Extract_Secret.py +108 -0
- tools/__init__.py +3 -0
- tools/augment_imagenetc.py +155 -0
- tools/base_lmdb.py +588 -0
- tools/ecc.py +281 -0
- tools/eval_metrics.py +130 -0
- tools/fid.py +672 -0
- tools/fid_lmdb.py +683 -0
- tools/gradcam.py +152 -0
- tools/helpers.py +416 -0
- tools/hparams.py +743 -0
- tools/image_dataset.py +184 -0
- tools/image_dataset_generic.py +157 -0
- tools/image_tools.py +164 -0
- tools/imgcap_dataset.py +163 -0
- tools/sifid.py +246 -0
- tools/slack_bot.py +157 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
Dockerfile
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM tuvbui/torchcpu:torch111
|
2 |
+
ADD cldm ./cldm
|
3 |
+
ADD flae ./flae
|
4 |
+
ADD ldm ./ldm
|
5 |
+
ADD tools ./tools
|
6 |
+
ADD pages ./pages
|
7 |
+
ADD Embed_Secret.py .
|
8 |
+
|
9 |
+
EXPOSE 7860
|
10 |
+
CMD streamlit run Embed_Secret.py --server.enableXsrfProtection=false --server.port 7860 -- --weight https://kahlan.cvssp.org/data/Flickr25K/tubui/stega/unet100b_croprs/epoch=000070-step=000219999.ckpt --config https://kahlan.cvssp.org/data/Flickr25K/tubui/stega/unet100b_croprs/-project.yaml
|
Embed_Secret.py
ADDED
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
streamlit app demo
|
5 |
+
how to run:
|
6 |
+
streamlit run app.py --server.port 8501
|
7 |
+
|
8 |
+
@author: Tu Bui @surrey.ac.uk
|
9 |
+
"""
|
10 |
+
import os, sys, torch
|
11 |
+
import argparse
|
12 |
+
from pathlib import Path
|
13 |
+
import numpy as np
|
14 |
+
import pickle
|
15 |
+
import pytorch_lightning as pl
|
16 |
+
from torchvision import transforms
|
17 |
+
import argparse
|
18 |
+
from ldm.util import instantiate_from_config
|
19 |
+
from omegaconf import OmegaConf
|
20 |
+
from PIL import Image
|
21 |
+
from tools.augment_imagenetc import RandomImagenetC
|
22 |
+
from io import BytesIO
|
23 |
+
from tools.helpers import welcome_message
|
24 |
+
from tools.ecc import BCH, RSC
|
25 |
+
|
26 |
+
import streamlit as st
|
27 |
+
from streamlit.source_util import (
|
28 |
+
page_icon_and_name,
|
29 |
+
calc_md5,
|
30 |
+
get_pages,
|
31 |
+
_on_pages_changed
|
32 |
+
)
|
33 |
+
|
34 |
+
model_names = ['UNet']
|
35 |
+
|
36 |
+
|
37 |
+
def delete_page(main_script_path_str, page_name):
|
38 |
+
|
39 |
+
current_pages = get_pages(main_script_path_str)
|
40 |
+
|
41 |
+
for key, value in current_pages.items():
|
42 |
+
print(value['page_name'])
|
43 |
+
if value['page_name'] == page_name:
|
44 |
+
del current_pages[key]
|
45 |
+
break
|
46 |
+
else:
|
47 |
+
pass
|
48 |
+
_on_pages_changed.send()
|
49 |
+
|
50 |
+
|
51 |
+
def add_page(main_script_path_str, page_name):
|
52 |
+
|
53 |
+
pages = get_pages(main_script_path_str)
|
54 |
+
main_script_path = Path(main_script_path_str)
|
55 |
+
pages_dir = main_script_path.parent / "pages"
|
56 |
+
# st.write(list(pages_dir.glob("*.py"))+list(main_script_path.parent.glob("*.py")))
|
57 |
+
script_path = [f for f in list(pages_dir.glob("*.py"))+list(main_script_path.parent.glob("*.py")) if f.name.find(page_name) != -1][0]
|
58 |
+
script_path_str = str(script_path.resolve())
|
59 |
+
pi, pn = page_icon_and_name(script_path)
|
60 |
+
psh = calc_md5(script_path_str)
|
61 |
+
pages[psh] = {
|
62 |
+
"page_script_hash": psh,
|
63 |
+
"page_name": pn,
|
64 |
+
"icon": pi,
|
65 |
+
"script_path": script_path_str,
|
66 |
+
}
|
67 |
+
_on_pages_changed.send()
|
68 |
+
|
69 |
+
def unormalize(x):
|
70 |
+
# convert x in range [-1, 1], (B,C,H,W), tensor to [0, 255], uint8, numpy, (B,H,W,C)
|
71 |
+
x = torch.clamp((x + 1) * 127.5, 0, 255).permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
|
72 |
+
return x
|
73 |
+
|
74 |
+
def to_bytes(x, mime):
|
75 |
+
x = Image.fromarray(x)
|
76 |
+
buf = BytesIO()
|
77 |
+
f = "JPEG" if mime == 'image/jpeg' else "PNG"
|
78 |
+
x.save(buf, format=f)
|
79 |
+
byte_im = buf.getvalue()
|
80 |
+
return byte_im
|
81 |
+
|
82 |
+
|
83 |
+
def load_UNet(args):
|
84 |
+
print('args: ', args)
|
85 |
+
# # crop safe model
|
86 |
+
# config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_tform2/configs/-project.yaml'
|
87 |
+
# weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_tform2/checkpoints/epoch=000060-step=000189999.ckpt'
|
88 |
+
|
89 |
+
# # resized crop safe model
|
90 |
+
# config_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml'
|
91 |
+
# weight_file = '/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt'
|
92 |
+
|
93 |
+
config_file = args.config
|
94 |
+
weight_file = args.weight
|
95 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
96 |
+
if weight_file.startswith('http'): # download from url
|
97 |
+
weight_dir = Path('./weights')
|
98 |
+
weight_dir.mkdir(exist_ok=True)
|
99 |
+
weight_path = weight_dir / weight_file.split('/')[-1]
|
100 |
+
config_path = weight_dir / config_file.split('/')[-1]
|
101 |
+
if not weight_path.exists():
|
102 |
+
import wget
|
103 |
+
print(f'Downloading {weight_file}...')
|
104 |
+
with st.spinner("Downloading model... this may take awhile!"):
|
105 |
+
wget.download(weight_file, str(weight_path))
|
106 |
+
wget.download(config_file, str(config_path))
|
107 |
+
weight_file = str(weight_path)
|
108 |
+
config_file = str(config_path)
|
109 |
+
|
110 |
+
config = OmegaConf.load(config_file).model
|
111 |
+
secret_len = config.params.secret_len
|
112 |
+
print(f'Secret length: {secret_len}')
|
113 |
+
model = instantiate_from_config(config)
|
114 |
+
state_dict = torch.load(weight_file, map_location=torch.device('cpu'))
|
115 |
+
if 'global_step' in state_dict:
|
116 |
+
print(f'Global step: {state_dict["global_step"]}, epoch: {state_dict["epoch"]}')
|
117 |
+
|
118 |
+
if 'state_dict' in state_dict:
|
119 |
+
state_dict = state_dict['state_dict']
|
120 |
+
misses, ignores = model.load_state_dict(state_dict, strict=False)
|
121 |
+
print(f'Missed keys: {misses}\nIgnore keys: {ignores}')
|
122 |
+
model = model.to(device)
|
123 |
+
model.eval()
|
124 |
+
return model, secret_len
|
125 |
+
|
126 |
+
def embed_secret(model_name, model, cover, tform, secret):
|
127 |
+
if model_name == 'UNet':
|
128 |
+
w, h = cover.size
|
129 |
+
with torch.no_grad():
|
130 |
+
im = tform(cover).unsqueeze(0).to(model.device) # 1, 3, 256, 256
|
131 |
+
stego, _ = model(im, secret) # 1, 3, 256, 256
|
132 |
+
res = (stego.clamp(-1,1) - im) # (1,3,256,256) residual
|
133 |
+
res = torch.nn.functional.interpolate(res, (h,w), mode='bilinear')
|
134 |
+
res = res.permute(0,2,3,1).cpu().numpy() # (1,256,256,3)
|
135 |
+
stego_uint8 = np.clip(res[0] + np.array(cover)/127.5-1., -1,1)*127.5+127.5 # (256, 256, 3), ndarray, uint8
|
136 |
+
stego_uint8 = stego_uint8.astype(np.uint8)
|
137 |
+
else:
|
138 |
+
raise NotImplementedError
|
139 |
+
return stego_uint8
|
140 |
+
|
141 |
+
def identity(x):
|
142 |
+
return x
|
143 |
+
|
144 |
+
def decode_secret(model_name, model, im, tform):
|
145 |
+
if model_name in ['RoSteALS', 'UNet']:
|
146 |
+
with torch.no_grad():
|
147 |
+
im = tform(im).unsqueeze(0).to(model.device) # 1, 3, 256, 256
|
148 |
+
secret_pred = (model.decoder(im) > 0).cpu().numpy() # 1, 100
|
149 |
+
else:
|
150 |
+
raise NotImplementedError
|
151 |
+
return secret_pred
|
152 |
+
|
153 |
+
|
154 |
+
@st.cache_resource
|
155 |
+
def load_model(model_name, _args):
|
156 |
+
if model_name == 'UNet':
|
157 |
+
tform_emb = transforms.Compose([
|
158 |
+
transforms.Resize((256,256)),
|
159 |
+
transforms.ToTensor(),
|
160 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
161 |
+
])
|
162 |
+
tform_det = transforms.Compose([
|
163 |
+
transforms.Resize((224,224)),
|
164 |
+
transforms.ToTensor(),
|
165 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
166 |
+
])
|
167 |
+
model, secret_len = load_UNet(_args)
|
168 |
+
else:
|
169 |
+
raise NotImplementedError
|
170 |
+
return model, tform_emb, tform_det, secret_len
|
171 |
+
|
172 |
+
|
173 |
+
@st.cache_resource
|
174 |
+
def load_ecc(ecc_name, secret_len):
|
175 |
+
if ecc_name == 'BCH':
|
176 |
+
if secret_len == 160:
|
177 |
+
ecc = BCH(285, 10, secret_len, verbose=True)
|
178 |
+
elif secret_len == 100:
|
179 |
+
ecc = BCH(137, 5, payload_len= secret_len, verbose=True)
|
180 |
+
elif ecc_name == 'RSC':
|
181 |
+
ecc = RSC(data_bytes=16, ecc_bytes=4, verbose=True)
|
182 |
+
return ecc
|
183 |
+
|
184 |
+
|
185 |
+
class Resize(object):
|
186 |
+
def __init__(self, size=None) -> None:
|
187 |
+
self.size = size
|
188 |
+
def __call__(self, x, size=None):
|
189 |
+
if isinstance(x, np.ndarray):
|
190 |
+
x = Image.fromarray(x)
|
191 |
+
new_size = size if size is not None else self.size
|
192 |
+
if min(x.size) > min(new_size): # downsample
|
193 |
+
x = x.resize(new_size, Image.LANCZOS)
|
194 |
+
else: # upsample
|
195 |
+
x = x.resize(new_size, Image.BILINEAR)
|
196 |
+
x = np.array(x)
|
197 |
+
return x
|
198 |
+
|
199 |
+
|
200 |
+
def parse_st_args():
|
201 |
+
# usage: streamlit run app.py -- --arg1 val1 --arg2 val2
|
202 |
+
parser = argparse.ArgumentParser()
|
203 |
+
parser.add_argument('--weight', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/checkpoints/epoch=000070-step=000219999.ckpt')
|
204 |
+
parser.add_argument('--config', default='/mnt/fast/nobackup/scratch4weeks/tb0035/projects/diffsteg/FLAE/simple_t2_croprs/configs/-project.yaml')
|
205 |
+
# parser.add_argument('--cpu', action='store_true')
|
206 |
+
args = parser.parse_args()
|
207 |
+
return args
|
208 |
+
|
209 |
+
|
210 |
+
def app(args):
|
211 |
+
# delete_page('Embed_Secret', 'Extract_Secret')
|
212 |
+
st.title('Watermarking Demo')
|
213 |
+
# setup model
|
214 |
+
model_name = st.selectbox("Choose the model", model_names)
|
215 |
+
model, tform_emb, tform_det, secret_len = load_model(model_name, args)
|
216 |
+
display_width = 300
|
217 |
+
# ecc
|
218 |
+
ecc = load_ecc('BCH', secret_len)
|
219 |
+
|
220 |
+
# setup st
|
221 |
+
st.subheader("Input")
|
222 |
+
image_file = st.file_uploader("Upload an image", type=["png","jpg","jpeg"])
|
223 |
+
if image_file is not None:
|
224 |
+
print('Image: ', image_file.name)
|
225 |
+
ext = image_file.name.split('.')[-1]
|
226 |
+
im = Image.open(image_file).convert('RGB')
|
227 |
+
size0 = im.size
|
228 |
+
st.image(im, width=display_width)
|
229 |
+
secret_text = st.text_input(f'Input the secret (max {ecc.data_len} chars)', 'A secret')
|
230 |
+
assert len(secret_text) <= ecc.data_len
|
231 |
+
|
232 |
+
# embed
|
233 |
+
st.subheader("Embed results")
|
234 |
+
status = st.empty()
|
235 |
+
prep = transforms.Compose([
|
236 |
+
transforms.Resize((256,256)),
|
237 |
+
transforms.CenterCrop((224,224))
|
238 |
+
])
|
239 |
+
if image_file is not None and secret_text is not None:
|
240 |
+
secret = ecc.encode_text([secret_text]) # (1, len)
|
241 |
+
secret = torch.from_numpy(secret).float().to(model.device)
|
242 |
+
# im = tform(im).unsqueeze(0).cuda() # (1,3,H,W)
|
243 |
+
stego = embed_secret(model_name, model, im, tform_emb, secret)
|
244 |
+
st.image(stego, width=display_width)
|
245 |
+
|
246 |
+
# download button
|
247 |
+
mime='image/jpeg' if ext=='jpg' else f'image/{ext}'
|
248 |
+
stego_bytes = to_bytes(stego, mime)
|
249 |
+
st.download_button(label='Download image', data=stego_bytes, file_name=f'stego.{ext}', mime=mime)
|
250 |
+
|
251 |
+
# verify secret
|
252 |
+
stego_processed = prep(Image.fromarray(stego))
|
253 |
+
secret_pred = decode_secret(model_name, model, stego_processed, tform_det)
|
254 |
+
bit_acc = (secret_pred == secret.cpu().numpy()).mean()
|
255 |
+
secret_pred = ecc.decode_text(secret_pred)[0]
|
256 |
+
status.markdown('**Secret recovery check:** ' + secret_pred, unsafe_allow_html=True)
|
257 |
+
status.markdown('**Bit accuracy:** ' + str(bit_acc), unsafe_allow_html=True)
|
258 |
+
|
259 |
+
if __name__ == '__main__':
|
260 |
+
args = parse_st_args()
|
261 |
+
app(args)
|
262 |
+
|
263 |
+
|
264 |
+
|
README.md
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: Test
|
3 |
+
emoji: 🐠
|
4 |
+
colorFrom: indigo
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: cc-by-nc-sa-4.0
|
9 |
+
duplicated_from: tubui/test
|
10 |
+
---
|
11 |
+
|
12 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
cldm/ae.py
ADDED
@@ -0,0 +1,727 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import einops
|
3 |
+
import torch
|
4 |
+
import torch as th
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as thf
|
7 |
+
import pytorch_lightning as pl
|
8 |
+
import torchvision
|
9 |
+
from copy import deepcopy
|
10 |
+
from ldm.modules.diffusionmodules.util import (
|
11 |
+
conv_nd,
|
12 |
+
linear,
|
13 |
+
zero_module,
|
14 |
+
timestep_embedding,
|
15 |
+
)
|
16 |
+
from contextlib import contextmanager, nullcontext
|
17 |
+
from einops import rearrange, repeat
|
18 |
+
from torchvision.utils import make_grid
|
19 |
+
from ldm.modules.attention import SpatialTransformer
|
20 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
21 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
22 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config, default
|
23 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
24 |
+
from ldm.modules.ema import LitEma
|
25 |
+
from ldm.modules.distributions.distributions import normal_kl, DiagonalGaussianDistribution
|
26 |
+
from ldm.modules.diffusionmodules.model import Encoder
|
27 |
+
import lpips
|
28 |
+
import kornia
|
29 |
+
from kornia import color
|
30 |
+
|
31 |
+
def disabled_train(self, mode=True):
|
32 |
+
"""Overwrite model.train with this function to make sure train/eval mode
|
33 |
+
does not change anymore."""
|
34 |
+
return self
|
35 |
+
|
36 |
+
class View(nn.Module):
|
37 |
+
def __init__(self, *shape):
|
38 |
+
super().__init__()
|
39 |
+
self.shape = shape
|
40 |
+
|
41 |
+
def forward(self, x):
|
42 |
+
return x.view(*self.shape)
|
43 |
+
|
44 |
+
|
45 |
+
class SecretEncoder3(nn.Module):
|
46 |
+
def __init__(self, secret_len, base_res=16, resolution=64) -> None:
|
47 |
+
super().__init__()
|
48 |
+
log_resolution = int(np.log2(resolution))
|
49 |
+
log_base = int(np.log2(base_res))
|
50 |
+
self.secret_len = secret_len
|
51 |
+
self.secret_scaler = nn.Sequential(
|
52 |
+
nn.Linear(secret_len, base_res*base_res*3),
|
53 |
+
nn.SiLU(),
|
54 |
+
View(-1, 3, base_res, base_res),
|
55 |
+
nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
|
56 |
+
zero_module(conv_nd(2, 3, 3, 3, padding=1))
|
57 |
+
) # secret len -> ch x res x res
|
58 |
+
|
59 |
+
def copy_encoder_weight(self, ae_model):
|
60 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
61 |
+
return None
|
62 |
+
|
63 |
+
def encode(self, x):
|
64 |
+
x = self.secret_scaler(x)
|
65 |
+
return x
|
66 |
+
|
67 |
+
def forward(self, x, c):
|
68 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
69 |
+
c = self.encode(c)
|
70 |
+
return c, None
|
71 |
+
|
72 |
+
|
73 |
+
class SecretEncoder4(nn.Module):
|
74 |
+
"""same as SecretEncoder3 but with ch as input"""
|
75 |
+
def __init__(self, secret_len, ch=3, base_res=16, resolution=64) -> None:
|
76 |
+
super().__init__()
|
77 |
+
log_resolution = int(np.log2(resolution))
|
78 |
+
log_base = int(np.log2(base_res))
|
79 |
+
self.secret_len = secret_len
|
80 |
+
self.secret_scaler = nn.Sequential(
|
81 |
+
nn.Linear(secret_len, base_res*base_res*ch),
|
82 |
+
nn.SiLU(),
|
83 |
+
View(-1, ch, base_res, base_res),
|
84 |
+
nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
|
85 |
+
zero_module(conv_nd(2, ch, ch, 3, padding=1))
|
86 |
+
) # secret len -> ch x res x res
|
87 |
+
|
88 |
+
def copy_encoder_weight(self, ae_model):
|
89 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
90 |
+
return None
|
91 |
+
|
92 |
+
def encode(self, x):
|
93 |
+
x = self.secret_scaler(x)
|
94 |
+
return x
|
95 |
+
|
96 |
+
def forward(self, x, c):
|
97 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
98 |
+
c = self.encode(c)
|
99 |
+
return c, None
|
100 |
+
|
101 |
+
class SecretEncoder6(nn.Module):
|
102 |
+
"""join img emb with secret emb"""
|
103 |
+
def __init__(self, secret_len, ch=3, base_res=16, resolution=64, emode='c3') -> None:
|
104 |
+
super().__init__()
|
105 |
+
assert emode in ['c3', 'c2', 'm3']
|
106 |
+
|
107 |
+
if emode == 'c3': # c3: concat c and x each has ch channels
|
108 |
+
secret_ch = ch
|
109 |
+
join_ch = 2*ch
|
110 |
+
elif emode == 'c2': # c2: concat c (2) and x ave (1)
|
111 |
+
secret_ch = 2
|
112 |
+
join_ch = ch
|
113 |
+
elif emode == 'm3': # m3: multiply c (ch) and x (ch)
|
114 |
+
secret_ch = ch
|
115 |
+
join_ch = ch
|
116 |
+
|
117 |
+
# m3: multiply c (ch) and x ave (1)
|
118 |
+
log_resolution = int(np.log2(resolution))
|
119 |
+
log_base = int(np.log2(base_res))
|
120 |
+
self.secret_len = secret_len
|
121 |
+
self.emode = emode
|
122 |
+
self.resolution = resolution
|
123 |
+
self.secret_scaler = nn.Sequential(
|
124 |
+
nn.Linear(secret_len, base_res*base_res*secret_ch),
|
125 |
+
nn.SiLU(),
|
126 |
+
View(-1, secret_ch, base_res, base_res),
|
127 |
+
nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
|
128 |
+
) # secret len -> ch x res x res
|
129 |
+
self.join_encoder = nn.Sequential(
|
130 |
+
conv_nd(2, join_ch, join_ch, 3, padding=1),
|
131 |
+
nn.SiLU(),
|
132 |
+
conv_nd(2, join_ch, ch, 3, padding=1),
|
133 |
+
nn.SiLU(),
|
134 |
+
conv_nd(2, ch, ch, 3, padding=1),
|
135 |
+
nn.SiLU()
|
136 |
+
)
|
137 |
+
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
|
138 |
+
|
139 |
+
def copy_encoder_weight(self, ae_model):
|
140 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
141 |
+
return None
|
142 |
+
|
143 |
+
def encode(self, x):
|
144 |
+
x = self.secret_scaler(x)
|
145 |
+
return x
|
146 |
+
|
147 |
+
def forward(self, x, c):
|
148 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
149 |
+
c = self.encode(c)
|
150 |
+
if self.emode == 'c3':
|
151 |
+
x = torch.cat([x, c], dim=1)
|
152 |
+
elif self.emode == 'c2':
|
153 |
+
x = torch.cat([x.mean(dim=1, keepdim=True), c], dim=1)
|
154 |
+
elif self.emode == 'm3':
|
155 |
+
x = x * c
|
156 |
+
dx = self.join_encoder(x)
|
157 |
+
dx = self.out_layer(dx)
|
158 |
+
return dx, None
|
159 |
+
|
160 |
+
class SecretEncoder5(nn.Module):
|
161 |
+
"""same as SecretEncoder3 but with ch as input"""
|
162 |
+
def __init__(self, secret_len, ch=3, base_res=16, resolution=64, joint=False) -> None:
|
163 |
+
super().__init__()
|
164 |
+
log_resolution = int(np.log2(resolution))
|
165 |
+
log_base = int(np.log2(base_res))
|
166 |
+
self.secret_len = secret_len
|
167 |
+
self.joint = joint
|
168 |
+
self.resolution = resolution
|
169 |
+
self.secret_scaler = nn.Sequential(
|
170 |
+
nn.Linear(secret_len, base_res*base_res*ch),
|
171 |
+
nn.SiLU(),
|
172 |
+
View(-1, ch, base_res, base_res),
|
173 |
+
nn.Upsample(scale_factor=(2**(log_resolution-log_base), 2**(log_resolution-log_base))), # chx16x16 -> chx256x256
|
174 |
+
) # secret len -> ch x res x res
|
175 |
+
if joint:
|
176 |
+
self.join_encoder = nn.Sequential(
|
177 |
+
conv_nd(2, 2*ch, 2*ch, 3, padding=1),
|
178 |
+
nn.SiLU(),
|
179 |
+
conv_nd(2, 2*ch, ch, 3, padding=1),
|
180 |
+
nn.SiLU()
|
181 |
+
)
|
182 |
+
self.out_layer = zero_module(conv_nd(2, ch, ch, 3, padding=1))
|
183 |
+
|
184 |
+
def copy_encoder_weight(self, ae_model):
|
185 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
186 |
+
return None
|
187 |
+
|
188 |
+
def encode(self, x):
|
189 |
+
x = self.secret_scaler(x)
|
190 |
+
return x
|
191 |
+
|
192 |
+
def forward(self, x, c):
|
193 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
194 |
+
c = self.encode(c)
|
195 |
+
if self.joint:
|
196 |
+
x = thf.interpolate(x, size=(self.resolution, self.resolution), mode="bilinear", align_corners=False, antialias=True)
|
197 |
+
c = self.join_encoder(torch.cat([x, c], dim=1))
|
198 |
+
c = self.out_layer(c)
|
199 |
+
return c, None
|
200 |
+
|
201 |
+
|
202 |
+
class SecretEncoder2(nn.Module):
|
203 |
+
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
|
204 |
+
ignore_keys=[],
|
205 |
+
image_key="image",
|
206 |
+
colorize_nlabels=None,
|
207 |
+
monitor=None,
|
208 |
+
ema_decay=None,
|
209 |
+
learn_logvar=False) -> None:
|
210 |
+
super().__init__()
|
211 |
+
log_resolution = int(np.log2(ddconfig.resolution))
|
212 |
+
self.secret_len = secret_len
|
213 |
+
self.learn_logvar = learn_logvar
|
214 |
+
self.image_key = image_key
|
215 |
+
self.encoder = Encoder(**ddconfig)
|
216 |
+
self.encoder.conv_out = zero_module(self.encoder.conv_out)
|
217 |
+
self.embed_dim = embed_dim
|
218 |
+
|
219 |
+
if colorize_nlabels is not None:
|
220 |
+
assert type(colorize_nlabels)==int
|
221 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
222 |
+
|
223 |
+
if monitor is not None:
|
224 |
+
self.monitor = monitor
|
225 |
+
|
226 |
+
self.secret_scaler = nn.Sequential(
|
227 |
+
nn.Linear(secret_len, 32*32*ddconfig.out_ch),
|
228 |
+
nn.SiLU(),
|
229 |
+
View(-1, ddconfig.out_ch, 32, 32),
|
230 |
+
nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
|
231 |
+
# zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
232 |
+
) # secret len -> ch x res x res
|
233 |
+
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
|
234 |
+
# self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
235 |
+
|
236 |
+
self.use_ema = ema_decay is not None
|
237 |
+
if self.use_ema:
|
238 |
+
self.ema_decay = ema_decay
|
239 |
+
assert 0. < ema_decay < 1.
|
240 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
241 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
242 |
+
|
243 |
+
if ckpt_path is not None:
|
244 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
245 |
+
|
246 |
+
|
247 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
248 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
249 |
+
keys = list(sd.keys())
|
250 |
+
for k in keys:
|
251 |
+
for ik in ignore_keys:
|
252 |
+
if k.startswith(ik):
|
253 |
+
print("Deleting key {} from state_dict.".format(k))
|
254 |
+
del sd[k]
|
255 |
+
misses, ignores = self.load_state_dict(sd, strict=False)
|
256 |
+
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
|
257 |
+
|
258 |
+
def copy_encoder_weight(self, ae_model):
|
259 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
260 |
+
return None
|
261 |
+
self.encoder.load_state_dict(ae_model.encoder.state_dict())
|
262 |
+
self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
|
263 |
+
|
264 |
+
@contextmanager
|
265 |
+
def ema_scope(self, context=None):
|
266 |
+
if self.use_ema:
|
267 |
+
self.model_ema.store(self.parameters())
|
268 |
+
self.model_ema.copy_to(self)
|
269 |
+
if context is not None:
|
270 |
+
print(f"{context}: Switched to EMA weights")
|
271 |
+
try:
|
272 |
+
yield None
|
273 |
+
finally:
|
274 |
+
if self.use_ema:
|
275 |
+
self.model_ema.restore(self.parameters())
|
276 |
+
if context is not None:
|
277 |
+
print(f"{context}: Restored training weights")
|
278 |
+
|
279 |
+
def on_train_batch_end(self, *args, **kwargs):
|
280 |
+
if self.use_ema:
|
281 |
+
self.model_ema(self)
|
282 |
+
|
283 |
+
def encode(self, x):
|
284 |
+
h = self.encoder(x)
|
285 |
+
posterior = h
|
286 |
+
return posterior
|
287 |
+
|
288 |
+
def forward(self, x, c):
|
289 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
290 |
+
c = self.secret_scaler(c)
|
291 |
+
x = torch.cat([x, c], dim=1)
|
292 |
+
z = self.encode(x)
|
293 |
+
# z = self.out_layer(z)
|
294 |
+
return z, None
|
295 |
+
|
296 |
+
|
297 |
+
class SecretEncoder7(nn.Module):
|
298 |
+
def __init__(self, secret_len, ddconfig, ckpt_path=None,
|
299 |
+
ignore_keys=[],embed_dim=3,
|
300 |
+
ema_decay=None) -> None:
|
301 |
+
super().__init__()
|
302 |
+
log_resolution = int(np.log2(ddconfig.resolution))
|
303 |
+
self.secret_len = secret_len
|
304 |
+
self.encoder = Encoder(**ddconfig)
|
305 |
+
# self.encoder.conv_out = zero_module(self.encoder.conv_out)
|
306 |
+
self.quant_conv = nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
307 |
+
|
308 |
+
self.secret_scaler = nn.Sequential(
|
309 |
+
nn.Linear(secret_len, 32*32*2),
|
310 |
+
nn.SiLU(),
|
311 |
+
View(-1, 2, 32, 32),
|
312 |
+
# nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
|
313 |
+
# zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
314 |
+
) # secret len -> ch x res x res
|
315 |
+
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
|
316 |
+
# self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
317 |
+
|
318 |
+
self.use_ema = ema_decay is not None
|
319 |
+
if self.use_ema:
|
320 |
+
self.ema_decay = ema_decay
|
321 |
+
assert 0. < ema_decay < 1.
|
322 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
323 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
324 |
+
|
325 |
+
if ckpt_path is not None:
|
326 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
327 |
+
|
328 |
+
|
329 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
330 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
331 |
+
keys = list(sd.keys())
|
332 |
+
for k in keys:
|
333 |
+
for ik in ignore_keys:
|
334 |
+
if k.startswith(ik):
|
335 |
+
print("Deleting key {} from state_dict.".format(k))
|
336 |
+
del sd[k]
|
337 |
+
misses, ignores = self.load_state_dict(sd, strict=False)
|
338 |
+
print(f"[SecretEncoder7] Restored from {path}, misses: {len(misses)}, ignores: {len(ignores)}. Do not worry as we are not using the decoder and the secret encoder is a novel module.")
|
339 |
+
|
340 |
+
def copy_encoder_weight(self, ae_model):
|
341 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
342 |
+
# return None
|
343 |
+
self.encoder.load_state_dict(deepcopy(ae_model.encoder.state_dict()))
|
344 |
+
self.quant_conv.load_state_dict(deepcopy(ae_model.quant_conv.state_dict()))
|
345 |
+
|
346 |
+
@contextmanager
|
347 |
+
def ema_scope(self, context=None):
|
348 |
+
if self.use_ema:
|
349 |
+
self.model_ema.store(self.parameters())
|
350 |
+
self.model_ema.copy_to(self)
|
351 |
+
if context is not None:
|
352 |
+
print(f"{context}: Switched to EMA weights")
|
353 |
+
try:
|
354 |
+
yield None
|
355 |
+
finally:
|
356 |
+
if self.use_ema:
|
357 |
+
self.model_ema.restore(self.parameters())
|
358 |
+
if context is not None:
|
359 |
+
print(f"{context}: Restored training weights")
|
360 |
+
|
361 |
+
def on_train_batch_end(self, *args, **kwargs):
|
362 |
+
if self.use_ema:
|
363 |
+
self.model_ema(self)
|
364 |
+
|
365 |
+
def encode(self, x):
|
366 |
+
h = self.encoder(x)
|
367 |
+
h = self.quant_conv(h)
|
368 |
+
return h
|
369 |
+
|
370 |
+
def forward(self, x, c):
|
371 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
372 |
+
c = self.secret_scaler(c) # [B, 2, 32, 32]
|
373 |
+
# c = thf.interpolate(c, size=x.shape[-2:], mode="bilinear", align_corners=False)
|
374 |
+
c = thf.interpolate(c, size=x.shape[-2:], mode="nearest")
|
375 |
+
x = 0.2125 * x[:,0,...] + 0.7154 *x[:,1,...] + 0.0721 * x[:,2,...]
|
376 |
+
x = torch.cat([x.unsqueeze(1), c], dim=1)
|
377 |
+
z = self.encode(x)
|
378 |
+
# z = self.out_layer(z)
|
379 |
+
return z, None
|
380 |
+
|
381 |
+
class SecretEncoder(nn.Module):
|
382 |
+
def __init__(self, secret_len, embed_dim, ddconfig, ckpt_path=None,
|
383 |
+
ignore_keys=[],
|
384 |
+
image_key="image",
|
385 |
+
colorize_nlabels=None,
|
386 |
+
monitor=None,
|
387 |
+
ema_decay=None,
|
388 |
+
learn_logvar=False) -> None:
|
389 |
+
super().__init__()
|
390 |
+
log_resolution = int(np.log2(ddconfig.resolution))
|
391 |
+
self.secret_len = secret_len
|
392 |
+
self.learn_logvar = learn_logvar
|
393 |
+
self.image_key = image_key
|
394 |
+
self.encoder = Encoder(**ddconfig)
|
395 |
+
assert ddconfig["double_z"]
|
396 |
+
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
397 |
+
self.embed_dim = embed_dim
|
398 |
+
|
399 |
+
if colorize_nlabels is not None:
|
400 |
+
assert type(colorize_nlabels)==int
|
401 |
+
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
402 |
+
|
403 |
+
if monitor is not None:
|
404 |
+
self.monitor = monitor
|
405 |
+
|
406 |
+
self.use_ema = ema_decay is not None
|
407 |
+
if self.use_ema:
|
408 |
+
self.ema_decay = ema_decay
|
409 |
+
assert 0. < ema_decay < 1.
|
410 |
+
self.model_ema = LitEma(self, decay=ema_decay)
|
411 |
+
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
412 |
+
|
413 |
+
if ckpt_path is not None:
|
414 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
415 |
+
|
416 |
+
self.secret_scaler = nn.Sequential(
|
417 |
+
nn.Linear(secret_len, 32*32*ddconfig.out_ch),
|
418 |
+
nn.SiLU(),
|
419 |
+
View(-1, ddconfig.out_ch, 32, 32),
|
420 |
+
nn.Upsample(scale_factor=(2**(log_resolution-5), 2**(log_resolution-5))), # chx16x16 -> chx256x256
|
421 |
+
zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
422 |
+
) # secret len -> ch x res x res
|
423 |
+
# out_resolution = ddconfig.resolution//(len(ddconfig.ch_mult)-1)
|
424 |
+
self.out_layer = zero_module(conv_nd(2, ddconfig.out_ch, ddconfig.out_ch, 3, padding=1))
|
425 |
+
|
426 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
427 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
428 |
+
keys = list(sd.keys())
|
429 |
+
for k in keys:
|
430 |
+
for ik in ignore_keys:
|
431 |
+
if k.startswith(ik):
|
432 |
+
print("Deleting key {} from state_dict.".format(k))
|
433 |
+
del sd[k]
|
434 |
+
misses, ignores = self.load_state_dict(sd, strict=False)
|
435 |
+
print(f"[SecretEncoder] Restored from {path}, misses: {misses}, ignores: {ignores}")
|
436 |
+
|
437 |
+
def copy_encoder_weight(self, ae_model):
|
438 |
+
# misses, ignores = self.load_state_dict(ae_state_dict, strict=False)
|
439 |
+
self.encoder.load_state_dict(ae_model.encoder.state_dict())
|
440 |
+
self.quant_conv.load_state_dict(ae_model.quant_conv.state_dict())
|
441 |
+
|
442 |
+
@contextmanager
|
443 |
+
def ema_scope(self, context=None):
|
444 |
+
if self.use_ema:
|
445 |
+
self.model_ema.store(self.parameters())
|
446 |
+
self.model_ema.copy_to(self)
|
447 |
+
if context is not None:
|
448 |
+
print(f"{context}: Switched to EMA weights")
|
449 |
+
try:
|
450 |
+
yield None
|
451 |
+
finally:
|
452 |
+
if self.use_ema:
|
453 |
+
self.model_ema.restore(self.parameters())
|
454 |
+
if context is not None:
|
455 |
+
print(f"{context}: Restored training weights")
|
456 |
+
|
457 |
+
def on_train_batch_end(self, *args, **kwargs):
|
458 |
+
if self.use_ema:
|
459 |
+
self.model_ema(self)
|
460 |
+
|
461 |
+
def encode(self, x):
|
462 |
+
h = self.encoder(x)
|
463 |
+
moments = self.quant_conv(h)
|
464 |
+
posterior = DiagonalGaussianDistribution(moments)
|
465 |
+
return posterior
|
466 |
+
|
467 |
+
def forward(self, x, c):
|
468 |
+
# x: [B, C, H, W], c: [B, secret_len]
|
469 |
+
c = self.secret_scaler(c)
|
470 |
+
x = x + c
|
471 |
+
posterior = self.encode(x)
|
472 |
+
z = posterior.sample()
|
473 |
+
z = self.out_layer(z)
|
474 |
+
return z, posterior
|
475 |
+
|
476 |
+
|
477 |
+
class ControlAE(pl.LightningModule):
|
478 |
+
def __init__(self,
|
479 |
+
first_stage_key,
|
480 |
+
first_stage_config,
|
481 |
+
control_key,
|
482 |
+
control_config,
|
483 |
+
decoder_config,
|
484 |
+
loss_config,
|
485 |
+
noise_config='__none__',
|
486 |
+
use_ema=False,
|
487 |
+
secret_warmup=False,
|
488 |
+
scale_factor=1.,
|
489 |
+
ckpt_path="__none__",
|
490 |
+
):
|
491 |
+
super().__init__()
|
492 |
+
self.scale_factor = scale_factor
|
493 |
+
self.control_key = control_key
|
494 |
+
self.first_stage_key = first_stage_key
|
495 |
+
self.ae = instantiate_from_config(first_stage_config)
|
496 |
+
self.control = instantiate_from_config(control_config)
|
497 |
+
self.decoder = instantiate_from_config(decoder_config)
|
498 |
+
self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
|
499 |
+
if noise_config != '__none__':
|
500 |
+
print('Using noise')
|
501 |
+
self.noise = instantiate_from_config(noise_config)
|
502 |
+
# copy weights from first stage
|
503 |
+
self.control.copy_encoder_weight(self.ae)
|
504 |
+
# freeze first stage
|
505 |
+
self.ae.eval()
|
506 |
+
self.ae.train = disabled_train
|
507 |
+
for p in self.ae.parameters():
|
508 |
+
p.requires_grad = False
|
509 |
+
|
510 |
+
self.loss_layer = instantiate_from_config(loss_config)
|
511 |
+
|
512 |
+
# early training phase
|
513 |
+
# self.fixed_input = True
|
514 |
+
self.fixed_x = None
|
515 |
+
self.fixed_img = None
|
516 |
+
self.fixed_input_recon = None
|
517 |
+
self.fixed_control = None
|
518 |
+
self.register_buffer("fixed_input", torch.tensor(True))
|
519 |
+
|
520 |
+
# secret warmup
|
521 |
+
self.secret_warmup = secret_warmup
|
522 |
+
self.secret_baselen = 2
|
523 |
+
self.secret_len = control_config.params.secret_len
|
524 |
+
if self.secret_warmup:
|
525 |
+
assert self.secret_len == 2**(int(np.log2(self.secret_len)))
|
526 |
+
|
527 |
+
self.use_ema = use_ema
|
528 |
+
if self.use_ema:
|
529 |
+
print('Using EMA')
|
530 |
+
self.control_ema = LitEma(self.control)
|
531 |
+
self.decoder_ema = LitEma(self.decoder)
|
532 |
+
print(f"Keeping EMAs of {len(list(self.control_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
|
533 |
+
|
534 |
+
if ckpt_path != '__none__':
|
535 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=[])
|
536 |
+
|
537 |
+
def get_warmup_secret(self, old_secret):
|
538 |
+
# old_secret: [B, secret_len]
|
539 |
+
# new_secret: [B, secret_len]
|
540 |
+
if self.secret_warmup:
|
541 |
+
bsz = old_secret.shape[0]
|
542 |
+
nrepeats = self.secret_len // self.secret_baselen
|
543 |
+
new_secret = torch.zeros((bsz, self.secret_baselen), dtype=torch.float).random_(0, 2).repeat_interleave(nrepeats, dim=1)
|
544 |
+
return new_secret.to(old_secret.device)
|
545 |
+
else:
|
546 |
+
return old_secret
|
547 |
+
|
548 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
549 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
550 |
+
keys = list(sd.keys())
|
551 |
+
for k in keys:
|
552 |
+
for ik in ignore_keys:
|
553 |
+
if k.startswith(ik):
|
554 |
+
print("Deleting key {} from state_dict.".format(k))
|
555 |
+
del sd[k]
|
556 |
+
self.load_state_dict(sd, strict=False)
|
557 |
+
print(f"Restored from {path}")
|
558 |
+
|
559 |
+
@contextmanager
|
560 |
+
def ema_scope(self, context=None):
|
561 |
+
if self.use_ema:
|
562 |
+
self.control_ema.store(self.control.parameters())
|
563 |
+
self.decoder_ema.store(self.decoder.parameters())
|
564 |
+
self.control_ema.copy_to(self.control)
|
565 |
+
self.decoder_ema.copy_to(self.decoder)
|
566 |
+
if context is not None:
|
567 |
+
print(f"{context}: Switched to EMA weights")
|
568 |
+
try:
|
569 |
+
yield None
|
570 |
+
finally:
|
571 |
+
if self.use_ema:
|
572 |
+
self.control_ema.restore(self.control.parameters())
|
573 |
+
self.decoder_ema.restore(self.decoder.parameters())
|
574 |
+
if context is not None:
|
575 |
+
print(f"{context}: Restored training weights")
|
576 |
+
|
577 |
+
def on_train_batch_end(self, *args, **kwargs):
|
578 |
+
if self.use_ema:
|
579 |
+
self.control_ema(self.control)
|
580 |
+
self.decoder_ema(self.decoder)
|
581 |
+
|
582 |
+
def compute_loss(self, pred, target):
|
583 |
+
# return thf.mse_loss(pred, target, reduction="none").mean(dim=(1, 2, 3))
|
584 |
+
lpips_loss = self.lpips_loss(pred, target).mean(dim=[1,2,3])
|
585 |
+
pred_yuv = color.rgb_to_yuv((pred + 1) / 2)
|
586 |
+
target_yuv = color.rgb_to_yuv((target + 1) / 2)
|
587 |
+
yuv_loss = torch.mean((pred_yuv - target_yuv)**2, dim=[2,3])
|
588 |
+
yuv_loss = 1.5*torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
|
589 |
+
return lpips_loss + yuv_loss
|
590 |
+
|
591 |
+
def forward(self, x, image, c):
|
592 |
+
if self.control.__class__.__name__ == 'SecretEncoder6':
|
593 |
+
eps, posterior = self.control(x, c)
|
594 |
+
else:
|
595 |
+
eps, posterior = self.control(image, c)
|
596 |
+
return x + eps, posterior
|
597 |
+
|
598 |
+
@torch.no_grad()
|
599 |
+
def get_input(self, batch, return_first_stage=False, bs=None):
|
600 |
+
image = batch[self.first_stage_key]
|
601 |
+
control = batch[self.control_key]
|
602 |
+
control = self.get_warmup_secret(control)
|
603 |
+
if bs is not None:
|
604 |
+
image = image[:bs]
|
605 |
+
control = control[:bs]
|
606 |
+
else:
|
607 |
+
bs = image.shape[0]
|
608 |
+
# encode image 1st stage
|
609 |
+
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
610 |
+
x = self.encode_first_stage(image).detach()
|
611 |
+
image_rec = self.decode_first_stage(x).detach()
|
612 |
+
|
613 |
+
# check if using fixed input (early training phase)
|
614 |
+
# if self.training and self.fixed_input:
|
615 |
+
if self.fixed_input:
|
616 |
+
if self.fixed_x is None: # first iteration
|
617 |
+
print('[TRAINING] Warmup - using fixed input image for now!')
|
618 |
+
self.fixed_x = x.detach().clone()[:bs]
|
619 |
+
self.fixed_img = image.detach().clone()[:bs]
|
620 |
+
self.fixed_input_recon = image_rec.detach().clone()[:bs]
|
621 |
+
self.fixed_control = control.detach().clone()[:bs] # use for log_images with fixed_input option only
|
622 |
+
x, image, image_rec = self.fixed_x, self.fixed_img, self.fixed_input_recon
|
623 |
+
|
624 |
+
out = [x, control]
|
625 |
+
if return_first_stage:
|
626 |
+
out.extend([image, image_rec])
|
627 |
+
return out
|
628 |
+
|
629 |
+
def decode_first_stage(self, z):
|
630 |
+
z = 1./self.scale_factor * z
|
631 |
+
image_rec = self.ae.decode(z)
|
632 |
+
return image_rec
|
633 |
+
|
634 |
+
def encode_first_stage(self, image):
|
635 |
+
encoder_posterior = self.ae.encode(image)
|
636 |
+
if isinstance(encoder_posterior, DiagonalGaussianDistribution):
|
637 |
+
z = encoder_posterior.sample()
|
638 |
+
elif isinstance(encoder_posterior, torch.Tensor):
|
639 |
+
z = encoder_posterior
|
640 |
+
else:
|
641 |
+
raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented")
|
642 |
+
return self.scale_factor * z
|
643 |
+
|
644 |
+
def shared_step(self, batch):
|
645 |
+
x, c, img, _ = self.get_input(batch, return_first_stage=True)
|
646 |
+
# import pdb; pdb.set_trace()
|
647 |
+
x, posterior = self(x, img, c)
|
648 |
+
image_rec = self.decode_first_stage(x)
|
649 |
+
# resize
|
650 |
+
if img.shape[-1] > 256:
|
651 |
+
img = thf.interpolate(img, size=(256, 256), mode='bilinear', align_corners=False).detach()
|
652 |
+
image_rec = thf.interpolate(image_rec, size=(256, 256), mode='bilinear', align_corners=False)
|
653 |
+
if hasattr(self, 'noise') and self.noise.is_activated():
|
654 |
+
image_rec_noised = self.noise(image_rec, self.global_step, p=0.9)
|
655 |
+
else:
|
656 |
+
image_rec_noised = self.crop(image_rec) # center crop
|
657 |
+
image_rec_noised = torch.clamp(image_rec_noised, -1, 1)
|
658 |
+
pred = self.decoder(image_rec_noised)
|
659 |
+
|
660 |
+
loss, loss_dict = self.loss_layer(img, image_rec, posterior, c, pred, self.global_step)
|
661 |
+
bit_acc = loss_dict["bit_acc"]
|
662 |
+
|
663 |
+
bit_acc_ = bit_acc.item()
|
664 |
+
|
665 |
+
if (bit_acc_ > 0.98) and (not self.fixed_input) and self.noise.is_activated():
|
666 |
+
self.loss_layer.activate_ramp(self.global_step)
|
667 |
+
|
668 |
+
if (bit_acc_ > 0.95) and (not self.fixed_input): # ramp up image loss at late training stage
|
669 |
+
if hasattr(self, 'noise') and (not self.noise.is_activated()):
|
670 |
+
self.noise.activate(self.global_step)
|
671 |
+
|
672 |
+
if (bit_acc_ > 0.9) and self.fixed_input: # execute only once
|
673 |
+
print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
|
674 |
+
self.fixed_input = ~self.fixed_input
|
675 |
+
return loss, loss_dict
|
676 |
+
|
677 |
+
def training_step(self, batch, batch_idx):
|
678 |
+
loss, loss_dict = self.shared_step(batch)
|
679 |
+
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
680 |
+
self.log_dict(loss_dict, prog_bar=True,
|
681 |
+
logger=True, on_step=True, on_epoch=True)
|
682 |
+
|
683 |
+
self.log("global_step", self.global_step,
|
684 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
685 |
+
# if self.use_scheduler:
|
686 |
+
# lr = self.optimizers().param_groups[0]['lr']
|
687 |
+
# self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
688 |
+
|
689 |
+
return loss
|
690 |
+
|
691 |
+
@torch.no_grad()
|
692 |
+
def validation_step(self, batch, batch_idx):
|
693 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
694 |
+
loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
|
695 |
+
with self.ema_scope():
|
696 |
+
_, loss_dict_ema = self.shared_step(batch)
|
697 |
+
loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
698 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
699 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
700 |
+
|
701 |
+
@torch.no_grad()
|
702 |
+
def log_images(self, batch, fixed_input=False, **kwargs):
|
703 |
+
log = dict()
|
704 |
+
if fixed_input and self.fixed_img is not None:
|
705 |
+
x, c, img, img_recon = self.fixed_x, self.fixed_control, self.fixed_img, self.fixed_input_recon
|
706 |
+
else:
|
707 |
+
x, c, img, img_recon = self.get_input(batch, return_first_stage=True)
|
708 |
+
x, _ = self(x, img, c)
|
709 |
+
image_out = self.decode_first_stage(x)
|
710 |
+
if hasattr(self, 'noise') and self.noise.is_activated():
|
711 |
+
img_noise = self.noise(image_out, self.global_step, p=1.0)
|
712 |
+
log['noised'] = img_noise
|
713 |
+
log['input'] = img
|
714 |
+
log['output'] = image_out
|
715 |
+
log['recon'] = img_recon
|
716 |
+
return log
|
717 |
+
|
718 |
+
def configure_optimizers(self):
|
719 |
+
lr = self.learning_rate
|
720 |
+
params = list(self.control.parameters()) + list(self.decoder.parameters())
|
721 |
+
optimizer = torch.optim.AdamW(params, lr=lr)
|
722 |
+
return optimizer
|
723 |
+
|
724 |
+
|
725 |
+
|
726 |
+
|
727 |
+
|
cldm/cldm.py
ADDED
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import einops
|
3 |
+
import torch
|
4 |
+
import torch as th
|
5 |
+
import torch.nn as nn
|
6 |
+
import torchvision
|
7 |
+
from ldm.modules.diffusionmodules.util import (
|
8 |
+
conv_nd,
|
9 |
+
linear,
|
10 |
+
zero_module,
|
11 |
+
timestep_embedding,
|
12 |
+
)
|
13 |
+
|
14 |
+
from einops import rearrange, repeat
|
15 |
+
from torchvision.utils import make_grid
|
16 |
+
from ldm.modules.attention import SpatialTransformer
|
17 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
18 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
19 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config
|
20 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
21 |
+
|
22 |
+
|
23 |
+
class ControlledUnetModel(UNetModel):
|
24 |
+
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
25 |
+
hs = []
|
26 |
+
with torch.no_grad():
|
27 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
28 |
+
emb = self.time_embed(t_emb)
|
29 |
+
h = x.type(self.dtype)
|
30 |
+
for module in self.input_blocks:
|
31 |
+
h = module(h, emb, context)
|
32 |
+
hs.append(h)
|
33 |
+
h = self.middle_block(h, emb, context)
|
34 |
+
|
35 |
+
h += control.pop()
|
36 |
+
|
37 |
+
for i, module in enumerate(self.output_blocks):
|
38 |
+
if only_mid_control:
|
39 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
40 |
+
else:
|
41 |
+
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
42 |
+
h = module(h, emb, context)
|
43 |
+
|
44 |
+
h = h.type(x.dtype)
|
45 |
+
return self.out(h)
|
46 |
+
|
47 |
+
class View(nn.Module):
|
48 |
+
def __init__(self, *shape):
|
49 |
+
super().__init__()
|
50 |
+
self.shape = shape
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
return x.view(*self.shape)
|
54 |
+
|
55 |
+
class ControlNet(nn.Module):
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
image_size,
|
59 |
+
in_channels,
|
60 |
+
model_channels,
|
61 |
+
hint_channels,
|
62 |
+
num_res_blocks,
|
63 |
+
attention_resolutions,
|
64 |
+
dropout=0,
|
65 |
+
channel_mult=(1, 2, 4, 8),
|
66 |
+
conv_resample=True,
|
67 |
+
dims=2,
|
68 |
+
use_checkpoint=False,
|
69 |
+
use_fp16=False,
|
70 |
+
num_heads=-1,
|
71 |
+
num_head_channels=-1,
|
72 |
+
num_heads_upsample=-1,
|
73 |
+
use_scale_shift_norm=False,
|
74 |
+
resblock_updown=False,
|
75 |
+
use_new_attention_order=False,
|
76 |
+
use_spatial_transformer=False, # custom transformer support
|
77 |
+
transformer_depth=1, # custom transformer support
|
78 |
+
context_dim=None, # custom transformer support
|
79 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
80 |
+
legacy=True,
|
81 |
+
disable_self_attentions=None,
|
82 |
+
num_attention_blocks=None,
|
83 |
+
disable_middle_self_attn=False,
|
84 |
+
use_linear_in_transformer=False,
|
85 |
+
secret_len = 0,
|
86 |
+
):
|
87 |
+
super().__init__()
|
88 |
+
if use_spatial_transformer:
|
89 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
90 |
+
|
91 |
+
if context_dim is not None:
|
92 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
93 |
+
from omegaconf.listconfig import ListConfig
|
94 |
+
if type(context_dim) == ListConfig:
|
95 |
+
context_dim = list(context_dim)
|
96 |
+
|
97 |
+
if num_heads_upsample == -1:
|
98 |
+
num_heads_upsample = num_heads
|
99 |
+
|
100 |
+
if num_heads == -1:
|
101 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
102 |
+
|
103 |
+
if num_head_channels == -1:
|
104 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
105 |
+
|
106 |
+
self.dims = dims
|
107 |
+
self.image_size = image_size
|
108 |
+
self.in_channels = in_channels
|
109 |
+
self.model_channels = model_channels
|
110 |
+
if isinstance(num_res_blocks, int):
|
111 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
112 |
+
else:
|
113 |
+
if len(num_res_blocks) != len(channel_mult):
|
114 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
115 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
116 |
+
self.num_res_blocks = num_res_blocks
|
117 |
+
if disable_self_attentions is not None:
|
118 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
119 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
120 |
+
if num_attention_blocks is not None:
|
121 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
122 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
123 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
124 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
125 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
126 |
+
f"attention will still not be set.")
|
127 |
+
|
128 |
+
self.attention_resolutions = attention_resolutions
|
129 |
+
self.dropout = dropout
|
130 |
+
self.channel_mult = channel_mult
|
131 |
+
self.conv_resample = conv_resample
|
132 |
+
self.use_checkpoint = use_checkpoint
|
133 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
134 |
+
self.num_heads = num_heads
|
135 |
+
self.num_head_channels = num_head_channels
|
136 |
+
self.num_heads_upsample = num_heads_upsample
|
137 |
+
self.predict_codebook_ids = n_embed is not None
|
138 |
+
|
139 |
+
time_embed_dim = model_channels * 4
|
140 |
+
self.time_embed = nn.Sequential(
|
141 |
+
linear(model_channels, time_embed_dim),
|
142 |
+
nn.SiLU(),
|
143 |
+
linear(time_embed_dim, time_embed_dim),
|
144 |
+
)
|
145 |
+
|
146 |
+
self.input_blocks = nn.ModuleList(
|
147 |
+
[
|
148 |
+
TimestepEmbedSequential(
|
149 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
150 |
+
)
|
151 |
+
]
|
152 |
+
)
|
153 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
154 |
+
self.secret_len = secret_len
|
155 |
+
if secret_len > 0:
|
156 |
+
log_resolution = int(np.log2(64))
|
157 |
+
self.input_hint_block = TimestepEmbedSequential(
|
158 |
+
nn.Linear(secret_len, 16*16*4),
|
159 |
+
nn.SiLU(),
|
160 |
+
View(-1, 4, 16, 16),
|
161 |
+
nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
|
162 |
+
conv_nd(dims, 4, 64, 3, padding=1),
|
163 |
+
nn.SiLU(),
|
164 |
+
conv_nd(dims, 64, 256, 3, padding=1),
|
165 |
+
nn.SiLU(),
|
166 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
167 |
+
)
|
168 |
+
else:
|
169 |
+
self.input_hint_block = TimestepEmbedSequential(
|
170 |
+
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
171 |
+
nn.SiLU(),
|
172 |
+
conv_nd(dims, 16, 16, 3, padding=1),
|
173 |
+
nn.SiLU(),
|
174 |
+
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
175 |
+
nn.SiLU(),
|
176 |
+
conv_nd(dims, 32, 32, 3, padding=1),
|
177 |
+
nn.SiLU(),
|
178 |
+
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
179 |
+
nn.SiLU(),
|
180 |
+
conv_nd(dims, 96, 96, 3, padding=1),
|
181 |
+
nn.SiLU(),
|
182 |
+
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
183 |
+
nn.SiLU(),
|
184 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
185 |
+
)
|
186 |
+
|
187 |
+
self._feature_size = model_channels
|
188 |
+
input_block_chans = [model_channels]
|
189 |
+
ch = model_channels
|
190 |
+
ds = 1
|
191 |
+
for level, mult in enumerate(channel_mult):
|
192 |
+
for nr in range(self.num_res_blocks[level]):
|
193 |
+
layers = [
|
194 |
+
ResBlock(
|
195 |
+
ch,
|
196 |
+
time_embed_dim,
|
197 |
+
dropout,
|
198 |
+
out_channels=mult * model_channels,
|
199 |
+
dims=dims,
|
200 |
+
use_checkpoint=use_checkpoint,
|
201 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
202 |
+
)
|
203 |
+
]
|
204 |
+
ch = mult * model_channels
|
205 |
+
if ds in attention_resolutions:
|
206 |
+
if num_head_channels == -1:
|
207 |
+
dim_head = ch // num_heads
|
208 |
+
else:
|
209 |
+
num_heads = ch // num_head_channels
|
210 |
+
dim_head = num_head_channels
|
211 |
+
if legacy:
|
212 |
+
#num_heads = 1
|
213 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
214 |
+
if exists(disable_self_attentions):
|
215 |
+
disabled_sa = disable_self_attentions[level]
|
216 |
+
else:
|
217 |
+
disabled_sa = False
|
218 |
+
|
219 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
220 |
+
layers.append(
|
221 |
+
AttentionBlock(
|
222 |
+
ch,
|
223 |
+
use_checkpoint=use_checkpoint,
|
224 |
+
num_heads=num_heads,
|
225 |
+
num_head_channels=dim_head,
|
226 |
+
use_new_attention_order=use_new_attention_order,
|
227 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
228 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
229 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
230 |
+
use_checkpoint=use_checkpoint
|
231 |
+
)
|
232 |
+
)
|
233 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
234 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
235 |
+
self._feature_size += ch
|
236 |
+
input_block_chans.append(ch)
|
237 |
+
if level != len(channel_mult) - 1:
|
238 |
+
out_ch = ch
|
239 |
+
self.input_blocks.append(
|
240 |
+
TimestepEmbedSequential(
|
241 |
+
ResBlock(
|
242 |
+
ch,
|
243 |
+
time_embed_dim,
|
244 |
+
dropout,
|
245 |
+
out_channels=out_ch,
|
246 |
+
dims=dims,
|
247 |
+
use_checkpoint=use_checkpoint,
|
248 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
249 |
+
down=True,
|
250 |
+
)
|
251 |
+
if resblock_updown
|
252 |
+
else Downsample(
|
253 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
254 |
+
)
|
255 |
+
)
|
256 |
+
)
|
257 |
+
ch = out_ch
|
258 |
+
input_block_chans.append(ch)
|
259 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
260 |
+
ds *= 2
|
261 |
+
self._feature_size += ch
|
262 |
+
|
263 |
+
if num_head_channels == -1:
|
264 |
+
dim_head = ch // num_heads
|
265 |
+
else:
|
266 |
+
num_heads = ch // num_head_channels
|
267 |
+
dim_head = num_head_channels
|
268 |
+
if legacy:
|
269 |
+
#num_heads = 1
|
270 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
271 |
+
self.middle_block = TimestepEmbedSequential(
|
272 |
+
ResBlock(
|
273 |
+
ch,
|
274 |
+
time_embed_dim,
|
275 |
+
dropout,
|
276 |
+
dims=dims,
|
277 |
+
use_checkpoint=use_checkpoint,
|
278 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
279 |
+
),
|
280 |
+
AttentionBlock(
|
281 |
+
ch,
|
282 |
+
use_checkpoint=use_checkpoint,
|
283 |
+
num_heads=num_heads,
|
284 |
+
num_head_channels=dim_head,
|
285 |
+
use_new_attention_order=use_new_attention_order,
|
286 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
287 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
288 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
289 |
+
use_checkpoint=use_checkpoint
|
290 |
+
),
|
291 |
+
ResBlock(
|
292 |
+
ch,
|
293 |
+
time_embed_dim,
|
294 |
+
dropout,
|
295 |
+
dims=dims,
|
296 |
+
use_checkpoint=use_checkpoint,
|
297 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
298 |
+
),
|
299 |
+
)
|
300 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
301 |
+
self._feature_size += ch
|
302 |
+
|
303 |
+
def make_zero_conv(self, channels):
|
304 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
305 |
+
|
306 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
307 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
308 |
+
emb = self.time_embed(t_emb)
|
309 |
+
# import pdb; pdb.set_trace()
|
310 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
311 |
+
|
312 |
+
outs = []
|
313 |
+
|
314 |
+
h = x.type(self.dtype)
|
315 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
316 |
+
if guided_hint is not None:
|
317 |
+
h = module(h, emb, context)
|
318 |
+
h += guided_hint
|
319 |
+
guided_hint = None
|
320 |
+
else:
|
321 |
+
h = module(h, emb, context)
|
322 |
+
outs.append(zero_conv(h, emb, context))
|
323 |
+
|
324 |
+
h = self.middle_block(h, emb, context)
|
325 |
+
outs.append(self.middle_block_out(h, emb, context))
|
326 |
+
|
327 |
+
return outs
|
328 |
+
|
329 |
+
|
330 |
+
class SecretDecoder(nn.Module):
|
331 |
+
def __init__(self, arch='CNN', act='ReLU', norm='none', resolution=256, in_channels=3, secret_len=100):
|
332 |
+
super().__init__()
|
333 |
+
self.resolution = resolution
|
334 |
+
self.arch = arch
|
335 |
+
print(f'SecretDecoder arch: {arch}')
|
336 |
+
def activation(name = 'ReLU'):
|
337 |
+
if name == 'ReLU':
|
338 |
+
return nn.ReLU()
|
339 |
+
elif name == 'LeakyReLU':
|
340 |
+
return nn.LeakyReLU()
|
341 |
+
elif name == 'SiLU':
|
342 |
+
return nn.SiLU()
|
343 |
+
|
344 |
+
def normalisation(name, n):
|
345 |
+
if name == 'none':
|
346 |
+
return nn.Identity()
|
347 |
+
elif name == 'BatchNorm2D':
|
348 |
+
return nn.BatchNorm2d(n)
|
349 |
+
elif name == 'BatchNorm1d':
|
350 |
+
return nn.BatchNorm1d(n)
|
351 |
+
elif name == 'LayerNorm':
|
352 |
+
return nn.LayerNorm(n)
|
353 |
+
|
354 |
+
if arch=='CNN':
|
355 |
+
self.decoder = nn.Sequential(
|
356 |
+
nn.Conv2d(in_channels, 32, (3, 3), 2, 1), # 128
|
357 |
+
activation(act),
|
358 |
+
nn.Conv2d(32, 32, 3, 1, 1),
|
359 |
+
activation(act),
|
360 |
+
nn.Conv2d(32, 64, 3, 2, 1), # 64
|
361 |
+
activation(act),
|
362 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
363 |
+
activation(act),
|
364 |
+
nn.Conv2d(64, 64, 3, 2, 1), # 32
|
365 |
+
activation(act),
|
366 |
+
nn.Conv2d(64, 128, 3, 2, 1), # 16
|
367 |
+
activation(act),
|
368 |
+
nn.Conv2d(128, 128, (3, 3), 2, 1), # 8
|
369 |
+
activation(act),
|
370 |
+
)
|
371 |
+
self.dense = nn.Sequential(
|
372 |
+
nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
|
373 |
+
activation(act),
|
374 |
+
nn.Linear(512, secret_len)
|
375 |
+
)
|
376 |
+
elif arch == 'resnet50':
|
377 |
+
self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
|
378 |
+
self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
|
379 |
+
else:
|
380 |
+
raise NotImplementedError
|
381 |
+
|
382 |
+
def forward(self, image):
|
383 |
+
x = self.decoder(image)
|
384 |
+
if self.arch == 'CNN':
|
385 |
+
x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
|
386 |
+
x = self.dense(x)
|
387 |
+
return x
|
388 |
+
|
389 |
+
|
390 |
+
class ControlLDM(LatentDiffusion):
|
391 |
+
|
392 |
+
def __init__(self, control_stage_config, control_key, only_mid_control, secret_decoder_config, *args, **kwargs):
|
393 |
+
super().__init__(*args, **kwargs)
|
394 |
+
self.control_model = instantiate_from_config(control_stage_config)
|
395 |
+
self.control_key = control_key
|
396 |
+
self.only_mid_control = only_mid_control
|
397 |
+
if secret_decoder_config != 'none':
|
398 |
+
self.secret_decoder = instantiate_from_config(secret_decoder_config)
|
399 |
+
|
400 |
+
@torch.no_grad()
|
401 |
+
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
402 |
+
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
|
403 |
+
control = batch[self.control_key]
|
404 |
+
if bs is not None:
|
405 |
+
control = control[:bs]
|
406 |
+
control = control.to(self.device)
|
407 |
+
if self.control_key == 'hint':
|
408 |
+
control = einops.rearrange(control, 'b h w c -> b c h w')
|
409 |
+
control = control.to(memory_format=torch.contiguous_format).float()
|
410 |
+
return x, dict(c_crossattn=[c], c_concat=[control])
|
411 |
+
|
412 |
+
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
413 |
+
assert isinstance(cond, dict)
|
414 |
+
diffusion_model = self.model.diffusion_model
|
415 |
+
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
416 |
+
cond_hint = torch.cat(cond['c_concat'], 1)
|
417 |
+
|
418 |
+
control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
|
419 |
+
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
420 |
+
|
421 |
+
return eps
|
422 |
+
|
423 |
+
@torch.no_grad()
|
424 |
+
def get_unconditional_conditioning(self, N):
|
425 |
+
return self.get_learned_conditioning([""] * N)
|
426 |
+
|
427 |
+
@torch.no_grad()
|
428 |
+
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
429 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
430 |
+
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
431 |
+
use_ema_scope=True,
|
432 |
+
**kwargs):
|
433 |
+
use_ddim = ddim_steps is not None
|
434 |
+
|
435 |
+
log = dict()
|
436 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
437 |
+
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
438 |
+
N = min(z.shape[0], N)
|
439 |
+
n_row = min(z.shape[0], n_row)
|
440 |
+
log["reconstruction"] = self.decode_first_stage(z)
|
441 |
+
log["control"] = c_cat * 2.0 - 1.0
|
442 |
+
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
443 |
+
|
444 |
+
if plot_diffusion_rows:
|
445 |
+
# get diffusion row
|
446 |
+
diffusion_row = list()
|
447 |
+
z_start = z[:n_row]
|
448 |
+
for t in range(self.num_timesteps):
|
449 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
450 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
451 |
+
t = t.to(self.device).long()
|
452 |
+
noise = torch.randn_like(z_start)
|
453 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
454 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
455 |
+
|
456 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
457 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
458 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
459 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
460 |
+
log["diffusion_row"] = diffusion_grid
|
461 |
+
|
462 |
+
if sample:
|
463 |
+
# get denoise row
|
464 |
+
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
465 |
+
batch_size=N, ddim=use_ddim,
|
466 |
+
ddim_steps=ddim_steps, eta=ddim_eta)
|
467 |
+
x_samples = self.decode_first_stage(samples)
|
468 |
+
log["samples"] = x_samples
|
469 |
+
if plot_denoise_rows:
|
470 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
471 |
+
log["denoise_row"] = denoise_grid
|
472 |
+
# import pudb; pudb.set_trace()
|
473 |
+
if unconditional_guidance_scale > 1.0:
|
474 |
+
uc_cross = self.get_unconditional_conditioning(N)
|
475 |
+
uc_cat = c_cat # torch.zeros_like(c_cat)
|
476 |
+
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
477 |
+
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
478 |
+
batch_size=N, ddim=use_ddim,
|
479 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
480 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
481 |
+
unconditional_conditioning=uc_full,
|
482 |
+
)
|
483 |
+
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
484 |
+
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
485 |
+
|
486 |
+
return log
|
487 |
+
|
488 |
+
@torch.no_grad()
|
489 |
+
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
490 |
+
ddim_sampler = DDIMSampler(self)
|
491 |
+
# import pdb; pdb.set_trace()
|
492 |
+
# b, c, h, w = cond["c_concat"][0].shape
|
493 |
+
b, c, h, w = cond["c_concat"][0].shape[0], self.channels, self.image_size*8, self.image_size*8
|
494 |
+
shape = (self.channels, h // 8, w // 8)
|
495 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
496 |
+
return samples, intermediates
|
497 |
+
|
498 |
+
def configure_optimizers(self):
|
499 |
+
lr = self.learning_rate
|
500 |
+
params = list(self.control_model.parameters())
|
501 |
+
if not self.sd_locked:
|
502 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
503 |
+
params += list(self.model.diffusion_model.out.parameters())
|
504 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
505 |
+
return opt
|
506 |
+
|
507 |
+
def low_vram_shift(self, is_diffusing):
|
508 |
+
if is_diffusing:
|
509 |
+
self.model = self.model.cuda()
|
510 |
+
self.control_model = self.control_model.cuda()
|
511 |
+
self.first_stage_model = self.first_stage_model.cpu()
|
512 |
+
self.cond_stage_model = self.cond_stage_model.cpu()
|
513 |
+
else:
|
514 |
+
self.model = self.model.cpu()
|
515 |
+
self.control_model = self.control_model.cpu()
|
516 |
+
self.first_stage_model = self.first_stage_model.cuda()
|
517 |
+
self.cond_stage_model = self.cond_stage_model.cuda()
|
cldm/diffsteg.py
ADDED
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import einops
|
3 |
+
import torch
|
4 |
+
import torch as th
|
5 |
+
import torch.nn as nn
|
6 |
+
from torch.nn import functional as thf
|
7 |
+
import torchvision
|
8 |
+
from ldm.modules.diffusionmodules.util import (
|
9 |
+
conv_nd,
|
10 |
+
linear,
|
11 |
+
zero_module,
|
12 |
+
timestep_embedding,
|
13 |
+
)
|
14 |
+
|
15 |
+
from einops import rearrange, repeat
|
16 |
+
from torchvision.utils import make_grid
|
17 |
+
from ldm.modules.attention import SpatialTransformer
|
18 |
+
from ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock
|
19 |
+
from ldm.models.diffusion.ddpm import LatentDiffusion
|
20 |
+
from ldm.util import log_txt_as_img, exists, instantiate_from_config, default
|
21 |
+
from ldm.models.diffusion.ddim import DDIMSampler
|
22 |
+
|
23 |
+
|
24 |
+
# class CUNetModel(nn.Module):
|
25 |
+
# def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
26 |
+
# hs = []
|
27 |
+
# with torch.no_grad():
|
28 |
+
# t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
29 |
+
# emb = self.time_embed(t_emb)
|
30 |
+
|
31 |
+
# h = x.type(self.dtype)
|
32 |
+
# for module in self.input_blocks:
|
33 |
+
# h = module(h, emb, context)
|
34 |
+
# hs.append(h)
|
35 |
+
|
36 |
+
# h = self.middle_block(h, emb, context)
|
37 |
+
# h += control.pop(0)
|
38 |
+
# for module in self.output_blocks:
|
39 |
+
# if only_mid_control:
|
40 |
+
# h = th.cat([h, hs.pop()], dim=1)
|
41 |
+
# else:
|
42 |
+
# h = torch.cat([h, hs.pop() + control.pop(0)], dim=1)
|
43 |
+
# h = module(h, emb, context)
|
44 |
+
# h = h.type(x.dtype)
|
45 |
+
# return self.out(h)
|
46 |
+
|
47 |
+
class SecretNet(nn.Module):
|
48 |
+
def __init__(
|
49 |
+
self,
|
50 |
+
image_size,
|
51 |
+
in_channels,
|
52 |
+
model_channels,
|
53 |
+
hint_channels,
|
54 |
+
num_res_blocks,
|
55 |
+
attention_resolutions,
|
56 |
+
dropout=0,
|
57 |
+
channel_mult=(1, 2, 4, 8),
|
58 |
+
conv_resample=True,
|
59 |
+
dims=2,
|
60 |
+
use_checkpoint=False,
|
61 |
+
use_fp16=False,
|
62 |
+
num_heads=-1,
|
63 |
+
num_head_channels=-1,
|
64 |
+
num_heads_upsample=-1,
|
65 |
+
use_scale_shift_norm=False,
|
66 |
+
resblock_updown=False,
|
67 |
+
use_new_attention_order=False,
|
68 |
+
use_spatial_transformer=False, # custom transformer support
|
69 |
+
transformer_depth=1, # custom transformer support
|
70 |
+
context_dim=None, # custom transformer support
|
71 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
72 |
+
legacy=True,
|
73 |
+
disable_self_attentions=None,
|
74 |
+
num_attention_blocks=None,
|
75 |
+
disable_middle_self_attn=False,
|
76 |
+
use_linear_in_transformer=False,
|
77 |
+
secret_len = 0,
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
if use_spatial_transformer:
|
81 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
82 |
+
|
83 |
+
if context_dim is not None:
|
84 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
85 |
+
from omegaconf.listconfig import ListConfig
|
86 |
+
if type(context_dim) == ListConfig:
|
87 |
+
context_dim = list(context_dim)
|
88 |
+
|
89 |
+
if num_heads_upsample == -1:
|
90 |
+
num_heads_upsample = num_heads
|
91 |
+
|
92 |
+
if num_heads == -1:
|
93 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
94 |
+
|
95 |
+
if num_head_channels == -1:
|
96 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
97 |
+
|
98 |
+
self.dims = dims
|
99 |
+
self.image_size = image_size
|
100 |
+
self.in_channels = in_channels
|
101 |
+
self.model_channels = model_channels
|
102 |
+
if isinstance(num_res_blocks, int):
|
103 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
104 |
+
else:
|
105 |
+
if len(num_res_blocks) != len(channel_mult):
|
106 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
107 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
108 |
+
self.num_res_blocks = num_res_blocks
|
109 |
+
if disable_self_attentions is not None:
|
110 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
111 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
112 |
+
if num_attention_blocks is not None:
|
113 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
114 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
115 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
116 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
117 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
118 |
+
f"attention will still not be set.")
|
119 |
+
|
120 |
+
self.attention_resolutions = attention_resolutions
|
121 |
+
self.dropout = dropout
|
122 |
+
self.channel_mult = channel_mult
|
123 |
+
self.conv_resample = conv_resample
|
124 |
+
self.use_checkpoint = use_checkpoint
|
125 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
126 |
+
self.num_heads = num_heads
|
127 |
+
self.num_head_channels = num_head_channels
|
128 |
+
self.num_heads_upsample = num_heads_upsample
|
129 |
+
self.predict_codebook_ids = n_embed is not None
|
130 |
+
|
131 |
+
time_embed_dim = model_channels * 4
|
132 |
+
self.time_embed = nn.Sequential(
|
133 |
+
linear(model_channels, time_embed_dim),
|
134 |
+
nn.SiLU(),
|
135 |
+
linear(time_embed_dim, time_embed_dim),
|
136 |
+
)
|
137 |
+
|
138 |
+
# self.input_blocks = nn.ModuleList(
|
139 |
+
# [
|
140 |
+
# TimestepEmbedSequential(
|
141 |
+
# conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
142 |
+
# )
|
143 |
+
# ]
|
144 |
+
# )
|
145 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
146 |
+
self.secret_len = secret_len
|
147 |
+
if secret_len > 0: # TODO: update for dec
|
148 |
+
log_resolution = int(np.log2(64))
|
149 |
+
self.input_hint_block = TimestepEmbedSequential(
|
150 |
+
nn.Linear(secret_len, 16*16*4),
|
151 |
+
nn.SiLU(),
|
152 |
+
View(-1, 4, 16, 16),
|
153 |
+
nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
|
154 |
+
conv_nd(dims, 4, 64, 3, padding=1),
|
155 |
+
nn.SiLU(),
|
156 |
+
conv_nd(dims, 64, 256, 3, padding=1),
|
157 |
+
nn.SiLU(),
|
158 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
159 |
+
)
|
160 |
+
|
161 |
+
self._feature_size = model_channels
|
162 |
+
input_block_chans = [model_channels]
|
163 |
+
ch = model_channels
|
164 |
+
ds = 1
|
165 |
+
for level, mult in enumerate(channel_mult):
|
166 |
+
for nr in range(self.num_res_blocks[level]):
|
167 |
+
layers = []
|
168 |
+
ch = mult * model_channels
|
169 |
+
if ds in attention_resolutions:
|
170 |
+
if num_head_channels == -1:
|
171 |
+
dim_head = ch // num_heads
|
172 |
+
else:
|
173 |
+
num_heads = ch // num_head_channels
|
174 |
+
dim_head = num_head_channels
|
175 |
+
if legacy:
|
176 |
+
#num_heads = 1
|
177 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
178 |
+
if exists(disable_self_attentions):
|
179 |
+
disabled_sa = disable_self_attentions[level]
|
180 |
+
else:
|
181 |
+
disabled_sa = False
|
182 |
+
|
183 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
184 |
+
layers.append(0)
|
185 |
+
# self.input_blocks.append(TimestepEmbedSequential(*layers))
|
186 |
+
# self.zero_convs.append(self.make_zero_conv(ch))
|
187 |
+
self._feature_size += ch
|
188 |
+
input_block_chans.append(ch)
|
189 |
+
if level != len(channel_mult) - 1:
|
190 |
+
out_ch = ch
|
191 |
+
self.input_blocks.append(
|
192 |
+
0
|
193 |
+
)
|
194 |
+
ch = out_ch
|
195 |
+
input_block_chans.append(ch)
|
196 |
+
# self.zero_convs.append(self.make_zero_conv(ch))
|
197 |
+
ds *= 2
|
198 |
+
self._feature_size += ch
|
199 |
+
|
200 |
+
if num_head_channels == -1:
|
201 |
+
dim_head = ch // num_heads
|
202 |
+
else:
|
203 |
+
num_heads = ch // num_head_channels
|
204 |
+
dim_head = num_head_channels
|
205 |
+
if legacy:
|
206 |
+
#num_heads = 1
|
207 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
208 |
+
self.middle_block = TimestepEmbedSequential(
|
209 |
+
ResBlock(
|
210 |
+
ch,
|
211 |
+
time_embed_dim,
|
212 |
+
dropout,
|
213 |
+
dims=dims,
|
214 |
+
use_checkpoint=use_checkpoint,
|
215 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
216 |
+
),
|
217 |
+
AttentionBlock(
|
218 |
+
ch,
|
219 |
+
use_checkpoint=use_checkpoint,
|
220 |
+
num_heads=num_heads,
|
221 |
+
num_head_channels=dim_head,
|
222 |
+
use_new_attention_order=use_new_attention_order,
|
223 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
224 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
225 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
226 |
+
use_checkpoint=use_checkpoint
|
227 |
+
),
|
228 |
+
ResBlock(
|
229 |
+
ch,
|
230 |
+
time_embed_dim,
|
231 |
+
dropout,
|
232 |
+
dims=dims,
|
233 |
+
use_checkpoint=use_checkpoint,
|
234 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
235 |
+
),
|
236 |
+
)
|
237 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
238 |
+
self._feature_size += ch
|
239 |
+
|
240 |
+
def make_zero_conv(self, channels):
|
241 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
242 |
+
|
243 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
244 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
245 |
+
emb = self.time_embed(t_emb)
|
246 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
247 |
+
# import pdb; pdb.set_trace()
|
248 |
+
outs = []
|
249 |
+
|
250 |
+
h = x.type(self.dtype)
|
251 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
252 |
+
if guided_hint is not None:
|
253 |
+
h = module(h, emb, context)
|
254 |
+
h += guided_hint
|
255 |
+
guided_hint = None
|
256 |
+
else:
|
257 |
+
h = module(h, emb, context)
|
258 |
+
outs.append(zero_conv(h, emb, context))
|
259 |
+
|
260 |
+
h = self.middle_block(h, emb, context)
|
261 |
+
outs.append(self.middle_block_out(h, emb, context))
|
262 |
+
|
263 |
+
return outs
|
264 |
+
|
265 |
+
class ControlledUnetModel(UNetModel):
|
266 |
+
def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs):
|
267 |
+
hs = []
|
268 |
+
with torch.no_grad():
|
269 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
270 |
+
emb = self.time_embed(t_emb)
|
271 |
+
h = x.type(self.dtype)
|
272 |
+
for module in self.input_blocks:
|
273 |
+
h = module(h, emb, context)
|
274 |
+
hs.append(h)
|
275 |
+
h = self.middle_block(h, emb, context)
|
276 |
+
|
277 |
+
h += control.pop()
|
278 |
+
|
279 |
+
for i, module in enumerate(self.output_blocks):
|
280 |
+
if only_mid_control:
|
281 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
282 |
+
else:
|
283 |
+
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
284 |
+
h = module(h, emb, context)
|
285 |
+
|
286 |
+
h = h.type(x.dtype)
|
287 |
+
return self.out(h)
|
288 |
+
|
289 |
+
class View(nn.Module):
|
290 |
+
def __init__(self, *shape):
|
291 |
+
super().__init__()
|
292 |
+
self.shape = shape
|
293 |
+
|
294 |
+
def forward(self, x):
|
295 |
+
return x.view(*self.shape)
|
296 |
+
|
297 |
+
class ControlNet(nn.Module):
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
image_size,
|
301 |
+
in_channels,
|
302 |
+
model_channels,
|
303 |
+
hint_channels,
|
304 |
+
num_res_blocks,
|
305 |
+
attention_resolutions,
|
306 |
+
dropout=0,
|
307 |
+
channel_mult=(1, 2, 4, 8),
|
308 |
+
conv_resample=True,
|
309 |
+
dims=2,
|
310 |
+
use_checkpoint=False,
|
311 |
+
use_fp16=False,
|
312 |
+
num_heads=-1,
|
313 |
+
num_head_channels=-1,
|
314 |
+
num_heads_upsample=-1,
|
315 |
+
use_scale_shift_norm=False,
|
316 |
+
resblock_updown=False,
|
317 |
+
use_new_attention_order=False,
|
318 |
+
use_spatial_transformer=False, # custom transformer support
|
319 |
+
transformer_depth=1, # custom transformer support
|
320 |
+
context_dim=None, # custom transformer support
|
321 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
322 |
+
legacy=True,
|
323 |
+
disable_self_attentions=None,
|
324 |
+
num_attention_blocks=None,
|
325 |
+
disable_middle_self_attn=False,
|
326 |
+
use_linear_in_transformer=False,
|
327 |
+
secret_len = 0,
|
328 |
+
):
|
329 |
+
super().__init__()
|
330 |
+
if use_spatial_transformer:
|
331 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
332 |
+
|
333 |
+
if context_dim is not None:
|
334 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
335 |
+
from omegaconf.listconfig import ListConfig
|
336 |
+
if type(context_dim) == ListConfig:
|
337 |
+
context_dim = list(context_dim)
|
338 |
+
|
339 |
+
if num_heads_upsample == -1:
|
340 |
+
num_heads_upsample = num_heads
|
341 |
+
|
342 |
+
if num_heads == -1:
|
343 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
344 |
+
|
345 |
+
if num_head_channels == -1:
|
346 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
347 |
+
|
348 |
+
self.dims = dims
|
349 |
+
self.image_size = image_size
|
350 |
+
self.in_channels = in_channels
|
351 |
+
self.model_channels = model_channels
|
352 |
+
if isinstance(num_res_blocks, int):
|
353 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
354 |
+
else:
|
355 |
+
if len(num_res_blocks) != len(channel_mult):
|
356 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
357 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
358 |
+
self.num_res_blocks = num_res_blocks
|
359 |
+
if disable_self_attentions is not None:
|
360 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
361 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
362 |
+
if num_attention_blocks is not None:
|
363 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
364 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
365 |
+
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
366 |
+
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
367 |
+
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
368 |
+
f"attention will still not be set.")
|
369 |
+
|
370 |
+
self.attention_resolutions = attention_resolutions
|
371 |
+
self.dropout = dropout
|
372 |
+
self.channel_mult = channel_mult
|
373 |
+
self.conv_resample = conv_resample
|
374 |
+
self.use_checkpoint = use_checkpoint
|
375 |
+
self.dtype = th.float16 if use_fp16 else th.float32
|
376 |
+
self.num_heads = num_heads
|
377 |
+
self.num_head_channels = num_head_channels
|
378 |
+
self.num_heads_upsample = num_heads_upsample
|
379 |
+
self.predict_codebook_ids = n_embed is not None
|
380 |
+
|
381 |
+
time_embed_dim = model_channels * 4
|
382 |
+
self.time_embed = nn.Sequential(
|
383 |
+
linear(model_channels, time_embed_dim),
|
384 |
+
nn.SiLU(),
|
385 |
+
linear(time_embed_dim, time_embed_dim),
|
386 |
+
)
|
387 |
+
|
388 |
+
self.input_blocks = nn.ModuleList(
|
389 |
+
[
|
390 |
+
TimestepEmbedSequential(
|
391 |
+
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
392 |
+
)
|
393 |
+
]
|
394 |
+
)
|
395 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
396 |
+
self.secret_len = secret_len
|
397 |
+
if secret_len > 0:
|
398 |
+
log_resolution = int(np.log2(64))
|
399 |
+
self.input_hint_block = TimestepEmbedSequential(
|
400 |
+
nn.Linear(secret_len, 16*16*4),
|
401 |
+
nn.SiLU(),
|
402 |
+
View(-1, 4, 16, 16),
|
403 |
+
nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4))),
|
404 |
+
conv_nd(dims, 4, 64, 3, padding=1),
|
405 |
+
nn.SiLU(),
|
406 |
+
conv_nd(dims, 64, 256, 3, padding=1),
|
407 |
+
nn.SiLU(),
|
408 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
409 |
+
)
|
410 |
+
else:
|
411 |
+
self.input_hint_block = TimestepEmbedSequential(
|
412 |
+
conv_nd(dims, hint_channels, 16, 3, padding=1),
|
413 |
+
nn.SiLU(),
|
414 |
+
conv_nd(dims, 16, 16, 3, padding=1),
|
415 |
+
nn.SiLU(),
|
416 |
+
conv_nd(dims, 16, 32, 3, padding=1, stride=2),
|
417 |
+
nn.SiLU(),
|
418 |
+
conv_nd(dims, 32, 32, 3, padding=1),
|
419 |
+
nn.SiLU(),
|
420 |
+
conv_nd(dims, 32, 96, 3, padding=1, stride=2),
|
421 |
+
nn.SiLU(),
|
422 |
+
conv_nd(dims, 96, 96, 3, padding=1),
|
423 |
+
nn.SiLU(),
|
424 |
+
conv_nd(dims, 96, 256, 3, padding=1, stride=2),
|
425 |
+
nn.SiLU(),
|
426 |
+
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1))
|
427 |
+
)
|
428 |
+
|
429 |
+
self._feature_size = model_channels
|
430 |
+
input_block_chans = [model_channels]
|
431 |
+
ch = model_channels
|
432 |
+
ds = 1
|
433 |
+
for level, mult in enumerate(channel_mult):
|
434 |
+
for nr in range(self.num_res_blocks[level]):
|
435 |
+
layers = [
|
436 |
+
ResBlock(
|
437 |
+
ch,
|
438 |
+
time_embed_dim,
|
439 |
+
dropout,
|
440 |
+
out_channels=mult * model_channels,
|
441 |
+
dims=dims,
|
442 |
+
use_checkpoint=use_checkpoint,
|
443 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
444 |
+
)
|
445 |
+
]
|
446 |
+
ch = mult * model_channels
|
447 |
+
if ds in attention_resolutions:
|
448 |
+
if num_head_channels == -1:
|
449 |
+
dim_head = ch // num_heads
|
450 |
+
else:
|
451 |
+
num_heads = ch // num_head_channels
|
452 |
+
dim_head = num_head_channels
|
453 |
+
if legacy:
|
454 |
+
#num_heads = 1
|
455 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
456 |
+
if exists(disable_self_attentions):
|
457 |
+
disabled_sa = disable_self_attentions[level]
|
458 |
+
else:
|
459 |
+
disabled_sa = False
|
460 |
+
|
461 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
462 |
+
layers.append(
|
463 |
+
AttentionBlock(
|
464 |
+
ch,
|
465 |
+
use_checkpoint=use_checkpoint,
|
466 |
+
num_heads=num_heads,
|
467 |
+
num_head_channels=dim_head,
|
468 |
+
use_new_attention_order=use_new_attention_order,
|
469 |
+
) if not use_spatial_transformer else SpatialTransformer(
|
470 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
471 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
472 |
+
use_checkpoint=use_checkpoint
|
473 |
+
)
|
474 |
+
)
|
475 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
476 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
477 |
+
self._feature_size += ch
|
478 |
+
input_block_chans.append(ch)
|
479 |
+
if level != len(channel_mult) - 1:
|
480 |
+
out_ch = ch
|
481 |
+
self.input_blocks.append(
|
482 |
+
TimestepEmbedSequential(
|
483 |
+
ResBlock(
|
484 |
+
ch,
|
485 |
+
time_embed_dim,
|
486 |
+
dropout,
|
487 |
+
out_channels=out_ch,
|
488 |
+
dims=dims,
|
489 |
+
use_checkpoint=use_checkpoint,
|
490 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
491 |
+
down=True,
|
492 |
+
)
|
493 |
+
if resblock_updown
|
494 |
+
else Downsample(
|
495 |
+
ch, conv_resample, dims=dims, out_channels=out_ch
|
496 |
+
)
|
497 |
+
)
|
498 |
+
)
|
499 |
+
ch = out_ch
|
500 |
+
input_block_chans.append(ch)
|
501 |
+
self.zero_convs.append(self.make_zero_conv(ch))
|
502 |
+
ds *= 2
|
503 |
+
self._feature_size += ch
|
504 |
+
|
505 |
+
if num_head_channels == -1:
|
506 |
+
dim_head = ch // num_heads
|
507 |
+
else:
|
508 |
+
num_heads = ch // num_head_channels
|
509 |
+
dim_head = num_head_channels
|
510 |
+
if legacy:
|
511 |
+
#num_heads = 1
|
512 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
513 |
+
self.middle_block = TimestepEmbedSequential(
|
514 |
+
ResBlock(
|
515 |
+
ch,
|
516 |
+
time_embed_dim,
|
517 |
+
dropout,
|
518 |
+
dims=dims,
|
519 |
+
use_checkpoint=use_checkpoint,
|
520 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
521 |
+
),
|
522 |
+
AttentionBlock(
|
523 |
+
ch,
|
524 |
+
use_checkpoint=use_checkpoint,
|
525 |
+
num_heads=num_heads,
|
526 |
+
num_head_channels=dim_head,
|
527 |
+
use_new_attention_order=use_new_attention_order,
|
528 |
+
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
529 |
+
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
530 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
531 |
+
use_checkpoint=use_checkpoint
|
532 |
+
),
|
533 |
+
ResBlock(
|
534 |
+
ch,
|
535 |
+
time_embed_dim,
|
536 |
+
dropout,
|
537 |
+
dims=dims,
|
538 |
+
use_checkpoint=use_checkpoint,
|
539 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
540 |
+
),
|
541 |
+
)
|
542 |
+
self.middle_block_out = self.make_zero_conv(ch)
|
543 |
+
self._feature_size += ch
|
544 |
+
|
545 |
+
def make_zero_conv(self, channels):
|
546 |
+
return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
547 |
+
|
548 |
+
def forward(self, x, hint, timesteps, context, **kwargs):
|
549 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
550 |
+
emb = self.time_embed(t_emb)
|
551 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
552 |
+
# import pdb; pdb.set_trace()
|
553 |
+
outs = []
|
554 |
+
|
555 |
+
h = x.type(self.dtype)
|
556 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
557 |
+
if guided_hint is not None:
|
558 |
+
h = module(h, emb, context)
|
559 |
+
h += guided_hint
|
560 |
+
guided_hint = None
|
561 |
+
else:
|
562 |
+
h = module(h, emb, context)
|
563 |
+
outs.append(zero_conv(h, emb, context))
|
564 |
+
|
565 |
+
h = self.middle_block(h, emb, context)
|
566 |
+
outs.append(self.middle_block_out(h, emb, context))
|
567 |
+
|
568 |
+
return outs
|
569 |
+
|
570 |
+
|
571 |
+
class SecretDecoder(nn.Module):
|
572 |
+
def __init__(self, arch='resnet50', secret_len=100):
|
573 |
+
super().__init__()
|
574 |
+
self.arch = arch
|
575 |
+
print(f'SecretDecoder arch: {arch}')
|
576 |
+
self.resolution = 224
|
577 |
+
if arch == 'resnet50':
|
578 |
+
self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
|
579 |
+
self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
|
580 |
+
elif arch == 'resnet18':
|
581 |
+
self.decoder = torchvision.models.resnet18(pretrained=True, progress=False)
|
582 |
+
self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
|
583 |
+
else:
|
584 |
+
raise NotImplementedError
|
585 |
+
|
586 |
+
def forward(self, image):
|
587 |
+
if self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution:
|
588 |
+
image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
|
589 |
+
x = self.decoder(image)
|
590 |
+
return x
|
591 |
+
|
592 |
+
|
593 |
+
class ControlLDM(LatentDiffusion):
|
594 |
+
|
595 |
+
def __init__(self, control_stage_config, control_key, only_mid_control, secret_decoder_config, *args, **kwargs):
|
596 |
+
super().__init__(*args, **kwargs)
|
597 |
+
self.control_model = instantiate_from_config(control_stage_config)
|
598 |
+
self.control_key = control_key
|
599 |
+
self.only_mid_control = only_mid_control
|
600 |
+
|
601 |
+
self.secret_decoder = None if secret_decoder_config == 'none' else instantiate_from_config(secret_decoder_config)
|
602 |
+
self.secret_loss_layer = nn.BCEWithLogitsLoss()
|
603 |
+
|
604 |
+
@torch.no_grad()
|
605 |
+
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
606 |
+
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
|
607 |
+
control = batch[self.control_key]
|
608 |
+
if bs is not None:
|
609 |
+
control = control[:bs]
|
610 |
+
control = control.to(self.device)
|
611 |
+
if self.control_key == 'hint':
|
612 |
+
control = einops.rearrange(control, 'b h w c -> b c h w')
|
613 |
+
control = control.to(memory_format=torch.contiguous_format).float()
|
614 |
+
return x, dict(c_crossattn=[c], c_concat=[control])
|
615 |
+
|
616 |
+
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
617 |
+
assert isinstance(cond, dict)
|
618 |
+
diffusion_model = self.model.diffusion_model
|
619 |
+
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
620 |
+
cond_hint = torch.cat(cond['c_concat'], 1)
|
621 |
+
|
622 |
+
control = self.control_model(x=x_noisy, hint=cond_hint, timesteps=t, context=cond_txt)
|
623 |
+
eps = diffusion_model(x=x_noisy, timesteps=t, context=cond_txt, control=control, only_mid_control=self.only_mid_control)
|
624 |
+
|
625 |
+
return eps
|
626 |
+
|
627 |
+
def p_losses(self, x_start, cond, t, noise=None):
|
628 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
629 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
630 |
+
model_output = self.apply_model(x_noisy, t, cond)
|
631 |
+
loss_dict = {}
|
632 |
+
prefix = 'train' if self.training else 'val'
|
633 |
+
|
634 |
+
if self.parameterization == "x0":
|
635 |
+
target = x_start
|
636 |
+
x_recon = model_output
|
637 |
+
elif self.parameterization == "eps":
|
638 |
+
target = noise
|
639 |
+
x_recon = self.predict_start_from_noise(x_noisy, t, noise=model_output)
|
640 |
+
elif self.parameterization == "v":
|
641 |
+
target = self.get_v(x_start, noise, t)
|
642 |
+
else:
|
643 |
+
raise NotImplementedError()
|
644 |
+
|
645 |
+
loss_simple = self.get_loss(model_output, target, mean=False).mean([1, 2, 3])
|
646 |
+
loss_dict.update({f'{prefix}/loss_simple': loss_simple.mean()})
|
647 |
+
|
648 |
+
logvar_t = self.logvar[t].to(self.device)
|
649 |
+
loss = loss_simple / torch.exp(logvar_t) + logvar_t
|
650 |
+
# loss = loss_simple / torch.exp(self.logvar) + self.logvar
|
651 |
+
if self.learn_logvar:
|
652 |
+
loss_dict.update({f'{prefix}/loss_gamma': loss.mean()})
|
653 |
+
loss_dict.update({'logvar': self.logvar.data.mean()})
|
654 |
+
|
655 |
+
loss = self.l_simple_weight * loss.mean()
|
656 |
+
|
657 |
+
loss_vlb = self.get_loss(model_output, target, mean=False).mean(dim=(1, 2, 3))
|
658 |
+
loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean()
|
659 |
+
loss_dict.update({f'{prefix}/loss_vlb': loss_vlb})
|
660 |
+
loss += (self.original_elbo_weight * loss_vlb)
|
661 |
+
# secret decode
|
662 |
+
if self.secret_decoder is not None:
|
663 |
+
simple_loss_weight = 0.1
|
664 |
+
x_recon = self.differentiable_decode_first_stage(x_recon)
|
665 |
+
secret_pred = self.secret_decoder(x_recon)
|
666 |
+
secret = cond['c_concat'][0]
|
667 |
+
loss_secret = self.secret_loss_layer(secret_pred, secret)
|
668 |
+
bit_acc = ((secret_pred.detach() > 0).float() == secret).float().mean()
|
669 |
+
loss_dict.update({f'{prefix}/bit_acc': bit_acc})
|
670 |
+
loss_dict.update({f'{prefix}/loss_secret': loss_secret})
|
671 |
+
loss = (loss*simple_loss_weight + loss_secret) / (simple_loss_weight + 1)
|
672 |
+
|
673 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
674 |
+
return loss, loss_dict
|
675 |
+
|
676 |
+
def differentiable_decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
|
677 |
+
if predict_cids:
|
678 |
+
if z.dim() == 4:
|
679 |
+
z = torch.argmax(z.exp(), dim=1).long()
|
680 |
+
z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
|
681 |
+
z = rearrange(z, 'b h w c -> b c h w').contiguous()
|
682 |
+
|
683 |
+
z = 1. / self.scale_factor * z
|
684 |
+
return self.first_stage_model.decode(z)
|
685 |
+
|
686 |
+
@torch.no_grad()
|
687 |
+
def get_unconditional_conditioning(self, N):
|
688 |
+
return self.get_learned_conditioning([""] * N)
|
689 |
+
|
690 |
+
@torch.no_grad()
|
691 |
+
def log_images(self, batch, N=4, n_row=2, sample=False, ddim_steps=50, ddim_eta=0.0, return_keys=None,
|
692 |
+
quantize_denoised=True, inpaint=True, plot_denoise_rows=False, plot_progressive_rows=True,
|
693 |
+
plot_diffusion_rows=False, unconditional_guidance_scale=9.0, unconditional_guidance_label=None,
|
694 |
+
use_ema_scope=True,
|
695 |
+
**kwargs):
|
696 |
+
use_ddim = ddim_steps is not None
|
697 |
+
|
698 |
+
log = dict()
|
699 |
+
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
700 |
+
c_cat, c = c["c_concat"][0][:N], c["c_crossattn"][0][:N]
|
701 |
+
N = min(z.shape[0], N)
|
702 |
+
n_row = min(z.shape[0], n_row)
|
703 |
+
log["reconstruction"] = self.decode_first_stage(z)
|
704 |
+
# log["control"] = c_cat * 2.0 - 1.0
|
705 |
+
log["conditioning"] = log_txt_as_img((512, 512), batch[self.cond_stage_key], size=16)
|
706 |
+
|
707 |
+
if plot_diffusion_rows:
|
708 |
+
# get diffusion row
|
709 |
+
diffusion_row = list()
|
710 |
+
z_start = z[:n_row]
|
711 |
+
for t in range(self.num_timesteps):
|
712 |
+
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
713 |
+
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
714 |
+
t = t.to(self.device).long()
|
715 |
+
noise = torch.randn_like(z_start)
|
716 |
+
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
717 |
+
diffusion_row.append(self.decode_first_stage(z_noisy))
|
718 |
+
|
719 |
+
diffusion_row = torch.stack(diffusion_row) # n_log_step, n_row, C, H, W
|
720 |
+
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
721 |
+
diffusion_grid = rearrange(diffusion_grid, 'b n c h w -> (b n) c h w')
|
722 |
+
diffusion_grid = make_grid(diffusion_grid, nrow=diffusion_row.shape[0])
|
723 |
+
log["diffusion_row"] = diffusion_grid
|
724 |
+
|
725 |
+
if sample:
|
726 |
+
# get denoise row
|
727 |
+
samples, z_denoise_row = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
728 |
+
batch_size=N, ddim=use_ddim,
|
729 |
+
ddim_steps=ddim_steps, eta=ddim_eta)
|
730 |
+
x_samples = self.decode_first_stage(samples)
|
731 |
+
log["samples"] = x_samples
|
732 |
+
if plot_denoise_rows:
|
733 |
+
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
734 |
+
log["denoise_row"] = denoise_grid
|
735 |
+
# import pudb; pudb.set_trace()
|
736 |
+
if unconditional_guidance_scale > 1.0:
|
737 |
+
uc_cross = self.get_unconditional_conditioning(N)
|
738 |
+
uc_cat = c_cat # torch.zeros_like(c_cat)
|
739 |
+
uc_full = {"c_concat": [uc_cat], "c_crossattn": [uc_cross]}
|
740 |
+
samples_cfg, _ = self.sample_log(cond={"c_concat": [c_cat], "c_crossattn": [c]},
|
741 |
+
batch_size=N, ddim=use_ddim,
|
742 |
+
ddim_steps=ddim_steps, eta=ddim_eta,
|
743 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
744 |
+
unconditional_conditioning=uc_full,
|
745 |
+
)
|
746 |
+
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
747 |
+
log[f"samples_cfg_scale_{unconditional_guidance_scale:.2f}"] = x_samples_cfg
|
748 |
+
|
749 |
+
return log
|
750 |
+
|
751 |
+
@torch.no_grad()
|
752 |
+
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
753 |
+
ddim_sampler = DDIMSampler(self)
|
754 |
+
# import pdb; pdb.set_trace()
|
755 |
+
# b, c, h, w = cond["c_concat"][0].shape
|
756 |
+
b, c, h, w = cond["c_concat"][0].shape[0], self.channels, self.image_size*8, self.image_size*8
|
757 |
+
shape = (self.channels, h // 8, w // 8)
|
758 |
+
samples, intermediates = ddim_sampler.sample(ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
759 |
+
return samples, intermediates
|
760 |
+
|
761 |
+
def configure_optimizers(self):
|
762 |
+
lr = self.learning_rate
|
763 |
+
params = list(self.control_model.parameters())
|
764 |
+
if self.secret_decoder is not None:
|
765 |
+
params += list(self.secret_decoder.parameters())
|
766 |
+
if not self.sd_locked:
|
767 |
+
params += list(self.model.diffusion_model.output_blocks.parameters())
|
768 |
+
params += list(self.model.diffusion_model.out.parameters())
|
769 |
+
opt = torch.optim.AdamW(params, lr=lr)
|
770 |
+
return opt
|
771 |
+
|
772 |
+
def low_vram_shift(self, is_diffusing):
|
773 |
+
if is_diffusing:
|
774 |
+
self.model = self.model.cuda()
|
775 |
+
self.control_model = self.control_model.cuda()
|
776 |
+
self.first_stage_model = self.first_stage_model.cpu()
|
777 |
+
self.cond_stage_model = self.cond_stage_model.cpu()
|
778 |
+
else:
|
779 |
+
self.model = self.model.cpu()
|
780 |
+
self.control_model = self.control_model.cpu()
|
781 |
+
self.first_stage_model = self.first_stage_model.cuda()
|
782 |
+
self.cond_stage_model = self.cond_stage_model.cuda()
|
cldm/hack.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import einops
|
3 |
+
|
4 |
+
import ldm.modules.encoders.modules
|
5 |
+
import ldm.modules.attention
|
6 |
+
|
7 |
+
from transformers import logging
|
8 |
+
from ldm.modules.attention import default
|
9 |
+
import warnings
|
10 |
+
|
11 |
+
def disable_verbosity():
|
12 |
+
logging.set_verbosity_error()
|
13 |
+
warnings.filterwarnings(action='ignore', category=DeprecationWarning)
|
14 |
+
warnings.filterwarnings(action='ignore', category=UserWarning)
|
15 |
+
print('logging improved.')
|
16 |
+
return
|
17 |
+
|
18 |
+
|
19 |
+
def enable_sliced_attention():
|
20 |
+
ldm.modules.attention.CrossAttention.forward = _hacked_sliced_attentin_forward
|
21 |
+
print('Enabled sliced_attention.')
|
22 |
+
return
|
23 |
+
|
24 |
+
|
25 |
+
def hack_everything(clip_skip=0):
|
26 |
+
disable_verbosity()
|
27 |
+
ldm.modules.encoders.modules.FrozenCLIPEmbedder.forward = _hacked_clip_forward
|
28 |
+
ldm.modules.encoders.modules.FrozenCLIPEmbedder.clip_skip = clip_skip
|
29 |
+
print('Enabled clip hacks.')
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
# Written by Lvmin
|
34 |
+
def _hacked_clip_forward(self, text):
|
35 |
+
PAD = self.tokenizer.pad_token_id
|
36 |
+
EOS = self.tokenizer.eos_token_id
|
37 |
+
BOS = self.tokenizer.bos_token_id
|
38 |
+
|
39 |
+
def tokenize(t):
|
40 |
+
return self.tokenizer(t, truncation=False, add_special_tokens=False)["input_ids"]
|
41 |
+
|
42 |
+
def transformer_encode(t):
|
43 |
+
if self.clip_skip > 1:
|
44 |
+
rt = self.transformer(input_ids=t, output_hidden_states=True)
|
45 |
+
return self.transformer.text_model.final_layer_norm(rt.hidden_states[-self.clip_skip])
|
46 |
+
else:
|
47 |
+
return self.transformer(input_ids=t, output_hidden_states=False).last_hidden_state
|
48 |
+
|
49 |
+
def split(x):
|
50 |
+
return x[75 * 0: 75 * 1], x[75 * 1: 75 * 2], x[75 * 2: 75 * 3]
|
51 |
+
|
52 |
+
def pad(x, p, i):
|
53 |
+
return x[:i] if len(x) >= i else x + [p] * (i - len(x))
|
54 |
+
|
55 |
+
raw_tokens_list = tokenize(text)
|
56 |
+
tokens_list = []
|
57 |
+
|
58 |
+
for raw_tokens in raw_tokens_list:
|
59 |
+
raw_tokens_123 = split(raw_tokens)
|
60 |
+
raw_tokens_123 = [[BOS] + raw_tokens_i + [EOS] for raw_tokens_i in raw_tokens_123]
|
61 |
+
raw_tokens_123 = [pad(raw_tokens_i, PAD, 77) for raw_tokens_i in raw_tokens_123]
|
62 |
+
tokens_list.append(raw_tokens_123)
|
63 |
+
|
64 |
+
tokens_list = torch.IntTensor(tokens_list).to(self.device)
|
65 |
+
|
66 |
+
feed = einops.rearrange(tokens_list, 'b f i -> (b f) i')
|
67 |
+
y = transformer_encode(feed)
|
68 |
+
z = einops.rearrange(y, '(b f) i c -> b (f i) c', f=3)
|
69 |
+
|
70 |
+
return z
|
71 |
+
|
72 |
+
|
73 |
+
# Stolen from https://github.com/basujindal/stable-diffusion/blob/main/optimizedSD/splitAttention.py
|
74 |
+
def _hacked_sliced_attentin_forward(self, x, context=None, mask=None):
|
75 |
+
h = self.heads
|
76 |
+
|
77 |
+
q = self.to_q(x)
|
78 |
+
context = default(context, x)
|
79 |
+
k = self.to_k(context)
|
80 |
+
v = self.to_v(context)
|
81 |
+
del context, x
|
82 |
+
|
83 |
+
q, k, v = map(lambda t: einops.rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
84 |
+
|
85 |
+
limit = k.shape[0]
|
86 |
+
att_step = 1
|
87 |
+
q_chunks = list(torch.tensor_split(q, limit // att_step, dim=0))
|
88 |
+
k_chunks = list(torch.tensor_split(k, limit // att_step, dim=0))
|
89 |
+
v_chunks = list(torch.tensor_split(v, limit // att_step, dim=0))
|
90 |
+
|
91 |
+
q_chunks.reverse()
|
92 |
+
k_chunks.reverse()
|
93 |
+
v_chunks.reverse()
|
94 |
+
sim = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
|
95 |
+
del k, q, v
|
96 |
+
for i in range(0, limit, att_step):
|
97 |
+
q_buffer = q_chunks.pop()
|
98 |
+
k_buffer = k_chunks.pop()
|
99 |
+
v_buffer = v_chunks.pop()
|
100 |
+
sim_buffer = torch.einsum('b i d, b j d -> b i j', q_buffer, k_buffer) * self.scale
|
101 |
+
|
102 |
+
del k_buffer, q_buffer
|
103 |
+
# attention, what we cannot get enough of, by chunks
|
104 |
+
|
105 |
+
sim_buffer = sim_buffer.softmax(dim=-1)
|
106 |
+
|
107 |
+
sim_buffer = torch.einsum('b i j, b j d -> b i d', sim_buffer, v_buffer)
|
108 |
+
del v_buffer
|
109 |
+
sim[i:i + att_step, :, :] = sim_buffer
|
110 |
+
|
111 |
+
del sim_buffer
|
112 |
+
sim = einops.rearrange(sim, '(b h) n d -> b n (h d)', h=h)
|
113 |
+
return self.to_out(sim)
|
cldm/logger.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from omegaconf import OmegaConf
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torchvision
|
6 |
+
from PIL import Image
|
7 |
+
from pytorch_lightning.callbacks import Callback
|
8 |
+
from pytorch_lightning.utilities.distributed import rank_zero_only
|
9 |
+
from pytorch_lightning.utilities import rank_zero_info
|
10 |
+
import time
|
11 |
+
|
12 |
+
|
13 |
+
class CUDACallback(Callback):
|
14 |
+
# see https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
|
15 |
+
def on_train_epoch_start(self, trainer, pl_module):
|
16 |
+
# Reset the memory use counter
|
17 |
+
torch.cuda.reset_peak_memory_stats(trainer.root_gpu)
|
18 |
+
torch.cuda.synchronize(trainer.root_gpu)
|
19 |
+
self.start_time = time.time()
|
20 |
+
|
21 |
+
def on_train_epoch_end(self, trainer, pl_module, outputs):
|
22 |
+
torch.cuda.synchronize(trainer.root_gpu)
|
23 |
+
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2 ** 20
|
24 |
+
epoch_time = (time.time() - self.start_time)/3600
|
25 |
+
|
26 |
+
try:
|
27 |
+
max_memory = trainer.training_type_plugin.reduce(max_memory)
|
28 |
+
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
|
29 |
+
|
30 |
+
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} hours")
|
31 |
+
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
|
32 |
+
except AttributeError:
|
33 |
+
pass
|
34 |
+
|
35 |
+
|
36 |
+
class SetupCallback(Callback):
|
37 |
+
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
|
38 |
+
super().__init__()
|
39 |
+
self.resume = resume
|
40 |
+
self.now = now
|
41 |
+
self.logdir = logdir
|
42 |
+
self.ckptdir = ckptdir
|
43 |
+
self.cfgdir = cfgdir
|
44 |
+
self.config = config
|
45 |
+
self.lightning_config = lightning_config
|
46 |
+
|
47 |
+
def on_keyboard_interrupt(self, trainer, pl_module):
|
48 |
+
if trainer.global_rank == 0:
|
49 |
+
print("Summoning checkpoint.")
|
50 |
+
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
|
51 |
+
trainer.save_checkpoint(ckpt_path)
|
52 |
+
|
53 |
+
def on_pretrain_routine_start(self, trainer, pl_module):
|
54 |
+
if trainer.global_rank == 0:
|
55 |
+
# Create logdirs and save configs
|
56 |
+
os.makedirs(self.logdir, exist_ok=True)
|
57 |
+
os.makedirs(self.ckptdir, exist_ok=True)
|
58 |
+
os.makedirs(self.cfgdir, exist_ok=True)
|
59 |
+
|
60 |
+
if "callbacks" in self.lightning_config:
|
61 |
+
if 'metrics_over_trainsteps_checkpoint' in self.lightning_config['callbacks']:
|
62 |
+
os.makedirs(os.path.join(self.ckptdir, 'trainstep_checkpoints'), exist_ok=True)
|
63 |
+
print("Project config")
|
64 |
+
print(OmegaConf.to_yaml(self.config))
|
65 |
+
OmegaConf.save(self.config,
|
66 |
+
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))
|
67 |
+
|
68 |
+
print("Lightning config")
|
69 |
+
print(OmegaConf.to_yaml(self.lightning_config))
|
70 |
+
OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
|
71 |
+
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))
|
72 |
+
|
73 |
+
else:
|
74 |
+
# ModelCheckpoint callback created log directory --- remove it
|
75 |
+
if not self.resume and os.path.exists(self.logdir):
|
76 |
+
dst, name = os.path.split(self.logdir)
|
77 |
+
dst = os.path.join(dst, "child_runs", name)
|
78 |
+
os.makedirs(os.path.split(dst)[0], exist_ok=True)
|
79 |
+
try:
|
80 |
+
os.rename(self.logdir, dst)
|
81 |
+
except FileNotFoundError:
|
82 |
+
pass
|
83 |
+
|
84 |
+
class ImageLogger(Callback):
|
85 |
+
def __init__(self, batch_frequency=2000, max_images=4, clamp=True, increase_log_steps=True,
|
86 |
+
rescale=True, disabled=False, log_on_batch_idx=False, log_first_step=False,
|
87 |
+
log_images_kwargs=None, fixed_input=False):
|
88 |
+
super().__init__()
|
89 |
+
self.rescale = rescale
|
90 |
+
self.batch_freq = batch_frequency
|
91 |
+
self.max_images = max_images
|
92 |
+
if not increase_log_steps:
|
93 |
+
self.log_steps = [self.batch_freq]
|
94 |
+
self.clamp = clamp
|
95 |
+
self.disabled = disabled
|
96 |
+
self.log_on_batch_idx = log_on_batch_idx
|
97 |
+
self.log_images_kwargs = log_images_kwargs if log_images_kwargs else {}
|
98 |
+
self.log_first_step = log_first_step
|
99 |
+
self.fixed_input = fixed_input
|
100 |
+
|
101 |
+
@rank_zero_only
|
102 |
+
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
|
103 |
+
root = os.path.join(save_dir, "image_log", split)
|
104 |
+
for k in images:
|
105 |
+
grid = torchvision.utils.make_grid(images[k], nrow=4)
|
106 |
+
if self.rescale:
|
107 |
+
grid = (grid + 1.0) / 2.0 # -1,1 -> 0,1; c,h,w
|
108 |
+
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
|
109 |
+
grid = grid.numpy()
|
110 |
+
grid = (grid * 255).astype(np.uint8)
|
111 |
+
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
|
112 |
+
path = os.path.join(root, filename)
|
113 |
+
os.makedirs(os.path.split(path)[0], exist_ok=True)
|
114 |
+
Image.fromarray(grid).save(path)
|
115 |
+
|
116 |
+
def log_img(self, pl_module, batch, batch_idx, split="train"):
|
117 |
+
check_idx = batch_idx # if self.log_on_batch_idx else pl_module.global_step
|
118 |
+
if (self.check_frequency(check_idx) and # batch_idx % self.batch_freq == 0
|
119 |
+
hasattr(pl_module, "log_images") and
|
120 |
+
callable(pl_module.log_images) and
|
121 |
+
self.max_images > 0):
|
122 |
+
logger = type(pl_module.logger)
|
123 |
+
|
124 |
+
is_train = pl_module.training
|
125 |
+
if is_train:
|
126 |
+
pl_module.eval()
|
127 |
+
|
128 |
+
with torch.no_grad():
|
129 |
+
images = pl_module.log_images(batch, fixed_input=self.fixed_input, split=split, **self.log_images_kwargs)
|
130 |
+
|
131 |
+
for k in images:
|
132 |
+
N = min(images[k].shape[0], self.max_images)
|
133 |
+
images[k] = images[k][:N]
|
134 |
+
if isinstance(images[k], torch.Tensor):
|
135 |
+
images[k] = images[k].detach().cpu()
|
136 |
+
if self.clamp:
|
137 |
+
images[k] = torch.clamp(images[k], -1., 1.)
|
138 |
+
self.log_local(pl_module.logger.save_dir, split, images,
|
139 |
+
pl_module.global_step, pl_module.current_epoch, batch_idx)
|
140 |
+
|
141 |
+
if is_train:
|
142 |
+
pl_module.train()
|
143 |
+
|
144 |
+
def check_frequency(self, check_idx):
|
145 |
+
return check_idx % self.batch_freq == 0
|
146 |
+
|
147 |
+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
|
148 |
+
if not self.disabled:
|
149 |
+
self.log_img(pl_module, batch, batch_idx, split="train")
|
cldm/loss.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from lpips import LPIPS
|
4 |
+
from kornia import color
|
5 |
+
# from taming.modules.losses.vqperceptual import *
|
6 |
+
|
7 |
+
class ImageSecretLoss(nn.Module):
|
8 |
+
def __init__(self, recon_type='rgb', recon_weight=1., perceptual_weight=1.0, secret_weight=10., kl_weight=0.000001, logvar_init=0.0, ramp=100000, max_image_weight_ratio=2.) -> None:
|
9 |
+
super().__init__()
|
10 |
+
self.recon_type = recon_type
|
11 |
+
assert recon_type in ['rgb', 'yuv']
|
12 |
+
if recon_type == 'yuv':
|
13 |
+
self.register_buffer('yuv_scales', torch.tensor([1,100,100]).unsqueeze(1).float()) # [3,1]
|
14 |
+
self.recon_weight = recon_weight
|
15 |
+
self.perceptual_weight = perceptual_weight
|
16 |
+
self.secret_weight = secret_weight
|
17 |
+
self.kl_weight = kl_weight
|
18 |
+
|
19 |
+
self.ramp = ramp
|
20 |
+
self.max_image_weight = max_image_weight_ratio * secret_weight - 1
|
21 |
+
self.register_buffer('ramp_on', torch.tensor(False))
|
22 |
+
self.register_buffer('step0', torch.tensor(1e9)) # large number
|
23 |
+
|
24 |
+
self.perceptual_loss = LPIPS().eval()
|
25 |
+
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
26 |
+
self.bce = nn.BCEWithLogitsLoss(reduction="none")
|
27 |
+
|
28 |
+
def activate_ramp(self, global_step):
|
29 |
+
if not self.ramp_on: # do not activate ramp twice
|
30 |
+
self.step0 = torch.tensor(global_step)
|
31 |
+
self.ramp_on = ~self.ramp_on
|
32 |
+
print('[TRAINING] Activate ramp for image loss at step ', global_step)
|
33 |
+
|
34 |
+
def compute_recon_loss(self, inputs, reconstructions):
|
35 |
+
if self.recon_type == 'rgb':
|
36 |
+
rec_loss = torch.abs(inputs - reconstructions).mean(dim=[1,2,3])
|
37 |
+
elif self.recon_type == 'yuv':
|
38 |
+
reconstructions_yuv = color.rgb_to_yuv((reconstructions + 1) / 2)
|
39 |
+
inputs_yuv = color.rgb_to_yuv((inputs + 1) / 2)
|
40 |
+
yuv_loss = torch.mean((reconstructions_yuv - inputs_yuv)**2, dim=[2,3])
|
41 |
+
rec_loss = torch.mm(yuv_loss, self.yuv_scales).squeeze(1)
|
42 |
+
else:
|
43 |
+
raise ValueError(f"Unknown recon type {self.recon_type}")
|
44 |
+
return rec_loss
|
45 |
+
|
46 |
+
def forward(self, inputs, reconstructions, posteriors, secret_gt, secret_pred, global_step):
|
47 |
+
loss_dict = {}
|
48 |
+
rec_loss = self.compute_recon_loss(inputs.contiguous(), reconstructions.contiguous())
|
49 |
+
|
50 |
+
loss = rec_loss*self.recon_weight
|
51 |
+
|
52 |
+
if self.perceptual_weight > 0:
|
53 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()).mean(dim=[1,2,3])
|
54 |
+
loss += self.perceptual_weight * p_loss
|
55 |
+
loss_dict['p_loss'] = p_loss.mean()
|
56 |
+
|
57 |
+
loss = loss / torch.exp(self.logvar) + self.logvar
|
58 |
+
if self.kl_weight > 0:
|
59 |
+
kl_loss = posteriors.kl()
|
60 |
+
loss += kl_loss*self.kl_weight
|
61 |
+
loss_dict['kl_loss'] = kl_loss.mean()
|
62 |
+
|
63 |
+
image_weight = 1 + min(self.max_image_weight, max(0., self.max_image_weight*(global_step - self.step0.item())/self.ramp))
|
64 |
+
|
65 |
+
secret_loss = self.bce(secret_pred, secret_gt).mean(dim=1)
|
66 |
+
loss = (loss*image_weight + secret_loss*self.secret_weight) / (image_weight+self.secret_weight)
|
67 |
+
|
68 |
+
# loss dict update
|
69 |
+
bit_acc = ((secret_pred.detach() > 0).float() == secret_gt).float().mean()
|
70 |
+
loss_dict['bit_acc'] = bit_acc
|
71 |
+
loss_dict['loss'] = loss.mean()
|
72 |
+
loss_dict['img_lw'] = image_weight/self.secret_weight
|
73 |
+
loss_dict['rec_loss'] = rec_loss.mean()
|
74 |
+
loss_dict['secret_loss'] = secret_loss.mean()
|
75 |
+
|
76 |
+
return loss.mean(), loss_dict
|
77 |
+
|
78 |
+
|
cldm/loss_weight_scheduler.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
|
5 |
+
@author: Tu Bui @University of Surrey
|
6 |
+
"""
|
7 |
+
|
8 |
+
class SimpleLossWeightScheduler(object):
|
9 |
+
def __init__(self, simple_loss_weight_max=10., wait_steps=50000, ramp=100000) -> None:
|
10 |
+
self.simple_loss_weight_max = simple_loss_weight_max
|
11 |
+
self.wait_steps = wait_steps
|
12 |
+
self.ramp = ramp
|
13 |
+
|
14 |
+
def __call__(self, step):
|
15 |
+
max_weight = self.simple_loss_weight_max - 1
|
16 |
+
w = 1 + min(max_weight, max(0., max_weight*(step - self.wait_steps)/self.ramp))
|
17 |
+
return w
|
cldm/model.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from ldm.util import instantiate_from_config
|
6 |
+
|
7 |
+
|
8 |
+
def get_state_dict(d):
|
9 |
+
return d.get('state_dict', d)
|
10 |
+
|
11 |
+
|
12 |
+
def load_state_dict(ckpt_path, location='cpu'):
|
13 |
+
_, extension = os.path.splitext(ckpt_path)
|
14 |
+
if extension.lower() == ".safetensors":
|
15 |
+
import safetensors.torch
|
16 |
+
state_dict = safetensors.torch.load_file(ckpt_path, device=location)
|
17 |
+
else:
|
18 |
+
state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
|
19 |
+
state_dict = get_state_dict(state_dict)
|
20 |
+
print(f'Loaded state_dict from [{ckpt_path}]')
|
21 |
+
return state_dict
|
22 |
+
|
23 |
+
|
24 |
+
def create_model(config_path):
|
25 |
+
config = OmegaConf.load(config_path)
|
26 |
+
model = instantiate_from_config(config.model).cpu()
|
27 |
+
print(f'Loaded model config from [{config_path}]')
|
28 |
+
return model
|
cldm/plms.py
ADDED
@@ -0,0 +1,1481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""SAMPLING ONLY."""
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
import torchvision
|
6 |
+
import numpy as np
|
7 |
+
from tqdm import tqdm
|
8 |
+
from functools import partial
|
9 |
+
from PIL import Image
|
10 |
+
import shutil
|
11 |
+
|
12 |
+
from ldm.modules.diffusionmodules.util import (
|
13 |
+
make_ddim_sampling_parameters,
|
14 |
+
make_ddim_timesteps,
|
15 |
+
noise_like,
|
16 |
+
)
|
17 |
+
import clip
|
18 |
+
from einops import rearrange
|
19 |
+
import random
|
20 |
+
|
21 |
+
|
22 |
+
class VGGPerceptualLoss(torch.nn.Module):
|
23 |
+
def __init__(self, resize=True):
|
24 |
+
super(VGGPerceptualLoss, self).__init__()
|
25 |
+
blocks = []
|
26 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[:4].eval())
|
27 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[4:9].eval())
|
28 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[9:16].eval())
|
29 |
+
blocks.append(torchvision.models.vgg16(pretrained=True).features[16:23].eval())
|
30 |
+
for bl in blocks:
|
31 |
+
for p in bl.parameters():
|
32 |
+
p.requires_grad = False
|
33 |
+
self.blocks = torch.nn.ModuleList(blocks)
|
34 |
+
self.transform = torch.nn.functional.interpolate
|
35 |
+
self.resize = resize
|
36 |
+
self.register_buffer(
|
37 |
+
"mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
38 |
+
)
|
39 |
+
self.register_buffer(
|
40 |
+
"std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
41 |
+
)
|
42 |
+
|
43 |
+
def forward(self, input, target, feature_layers=[0, 1, 2, 3], style_layers=[]):
|
44 |
+
input = (input - self.mean) / self.std
|
45 |
+
target = (target - self.mean) / self.std
|
46 |
+
if self.resize:
|
47 |
+
input = self.transform(
|
48 |
+
input, mode="bilinear", size=(224, 224), align_corners=False
|
49 |
+
)
|
50 |
+
target = self.transform(
|
51 |
+
target, mode="bilinear", size=(224, 224), align_corners=False
|
52 |
+
)
|
53 |
+
loss = 0.0
|
54 |
+
x = input
|
55 |
+
y = target
|
56 |
+
for i, block in enumerate(self.blocks):
|
57 |
+
x = block(x)
|
58 |
+
y = block(y)
|
59 |
+
if i in feature_layers:
|
60 |
+
loss += torch.nn.functional.l1_loss(x, y)
|
61 |
+
if i in style_layers:
|
62 |
+
act_x = x.reshape(x.shape[0], x.shape[1], -1)
|
63 |
+
act_y = y.reshape(y.shape[0], y.shape[1], -1)
|
64 |
+
gram_x = act_x @ act_x.permute(0, 2, 1)
|
65 |
+
gram_y = act_y @ act_y.permute(0, 2, 1)
|
66 |
+
loss += torch.nn.functional.l1_loss(gram_x, gram_y)
|
67 |
+
return loss
|
68 |
+
|
69 |
+
|
70 |
+
class DCLIPLoss(torch.nn.Module):
|
71 |
+
def __init__(self):
|
72 |
+
super(DCLIPLoss, self).__init__()
|
73 |
+
self.model, self.preprocess = clip.load("ViT-B/32", device="cuda")
|
74 |
+
self.upsample = torch.nn.Upsample(scale_factor=7)
|
75 |
+
self.avg_pool = torch.nn.AvgPool2d(kernel_size=16)
|
76 |
+
|
77 |
+
def forward(self, image1, image2, text1, text2):
|
78 |
+
text1 = clip.tokenize([text1]).to("cuda")
|
79 |
+
text2 = clip.tokenize([text2]).to("cuda")
|
80 |
+
image1 = image1.unsqueeze(0).cuda()
|
81 |
+
image2 = image2.unsqueeze(0)
|
82 |
+
image1 = self.avg_pool(self.upsample(image1))
|
83 |
+
image2 = self.avg_pool(self.upsample(image2))
|
84 |
+
image1_feat = self.model.encode_image(image1)
|
85 |
+
image2_feat = self.model.encode_image(image2)
|
86 |
+
text1_feat = self.model.encode_text(text1)
|
87 |
+
text2_feat = self.model.encode_text(text2)
|
88 |
+
d_image_feat = image1_feat - image2_feat
|
89 |
+
d_text_feat = text1_feat - text2_feat
|
90 |
+
similarity = torch.nn.CosineSimilarity()(d_image_feat, d_text_feat)
|
91 |
+
return 1 - similarity
|
92 |
+
|
93 |
+
|
94 |
+
class PLMSSampler(object):
|
95 |
+
def __init__(self, model, schedule="linear", **kwargs):
|
96 |
+
super().__init__()
|
97 |
+
self.model = model
|
98 |
+
self.ddpm_num_timesteps = model.num_timesteps
|
99 |
+
self.schedule = schedule
|
100 |
+
|
101 |
+
def register_buffer(self, name, attr):
|
102 |
+
if type(attr) == torch.Tensor:
|
103 |
+
if attr.device != torch.device("cuda"):
|
104 |
+
attr = attr.to(torch.device("cuda"))
|
105 |
+
setattr(self, name, attr)
|
106 |
+
|
107 |
+
def make_schedule(
|
108 |
+
self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
|
109 |
+
):
|
110 |
+
if ddim_eta != 0:
|
111 |
+
raise ValueError("ddim_eta must be 0 for PLMS")
|
112 |
+
self.ddim_timesteps = make_ddim_timesteps(
|
113 |
+
ddim_discr_method=ddim_discretize,
|
114 |
+
num_ddim_timesteps=ddim_num_steps,
|
115 |
+
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
116 |
+
verbose=verbose,
|
117 |
+
)
|
118 |
+
alphas_cumprod = self.model.alphas_cumprod
|
119 |
+
assert (
|
120 |
+
alphas_cumprod.shape[0] == self.ddpm_num_timesteps
|
121 |
+
), "alphas have to be defined for each timestep"
|
122 |
+
to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
|
123 |
+
|
124 |
+
self.register_buffer("betas", to_torch(self.model.betas))
|
125 |
+
self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
|
126 |
+
self.register_buffer(
|
127 |
+
"alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
|
128 |
+
)
|
129 |
+
|
130 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
131 |
+
self.register_buffer(
|
132 |
+
"sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
|
133 |
+
)
|
134 |
+
self.register_buffer(
|
135 |
+
"sqrt_one_minus_alphas_cumprod",
|
136 |
+
to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
|
137 |
+
)
|
138 |
+
self.register_buffer(
|
139 |
+
"log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
|
140 |
+
)
|
141 |
+
self.register_buffer(
|
142 |
+
"sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
|
143 |
+
)
|
144 |
+
self.register_buffer(
|
145 |
+
"sqrt_recipm1_alphas_cumprod",
|
146 |
+
to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
|
147 |
+
)
|
148 |
+
|
149 |
+
# ddim sampling parameters
|
150 |
+
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
151 |
+
alphacums=alphas_cumprod.cpu(),
|
152 |
+
ddim_timesteps=self.ddim_timesteps,
|
153 |
+
eta=0.0,
|
154 |
+
verbose=verbose,
|
155 |
+
)
|
156 |
+
self.register_buffer("ddim_sigmas", ddim_sigmas)
|
157 |
+
self.register_buffer("ddim_alphas", ddim_alphas)
|
158 |
+
self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
|
159 |
+
self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
|
160 |
+
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
|
161 |
+
(1 - self.alphas_cumprod_prev)
|
162 |
+
/ (1 - self.alphas_cumprod)
|
163 |
+
* (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
164 |
+
)
|
165 |
+
self.register_buffer(
|
166 |
+
"ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
|
167 |
+
)
|
168 |
+
|
169 |
+
@torch.no_grad()
|
170 |
+
def sample(self,
|
171 |
+
S,
|
172 |
+
batch_size,
|
173 |
+
shape,
|
174 |
+
conditioning=None,
|
175 |
+
callback=None,
|
176 |
+
normals_sequence=None,
|
177 |
+
img_callback=None,
|
178 |
+
quantize_x0=False,
|
179 |
+
eta=0.,
|
180 |
+
mask=None,
|
181 |
+
x0=None,
|
182 |
+
temperature=1.,
|
183 |
+
noise_dropout=0.,
|
184 |
+
score_corrector=None,
|
185 |
+
corrector_kwargs=None,
|
186 |
+
verbose=True,
|
187 |
+
x_T=None,
|
188 |
+
log_every_t=100,
|
189 |
+
unconditional_guidance_scale=1.,
|
190 |
+
unconditional_conditioning=None,
|
191 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
192 |
+
dynamic_threshold=None,
|
193 |
+
**kwargs
|
194 |
+
):
|
195 |
+
if conditioning is not None:
|
196 |
+
if isinstance(conditioning, dict):
|
197 |
+
cbs = conditioning[list(conditioning.keys())[0]][0].shape[0]
|
198 |
+
if cbs != batch_size:
|
199 |
+
print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
|
200 |
+
else:
|
201 |
+
if conditioning.shape[0] != batch_size:
|
202 |
+
print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
|
203 |
+
|
204 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
205 |
+
# sampling
|
206 |
+
C, H, W = shape
|
207 |
+
size = (batch_size, C, H, W)
|
208 |
+
print(f'Data shape for PLMS sampling is {size}')
|
209 |
+
|
210 |
+
samples, intermediates = self.plms_sampling(conditioning, size,
|
211 |
+
callback=callback,
|
212 |
+
img_callback=img_callback,
|
213 |
+
quantize_denoised=quantize_x0,
|
214 |
+
mask=mask, x0=x0,
|
215 |
+
ddim_use_original_steps=False,
|
216 |
+
noise_dropout=noise_dropout,
|
217 |
+
temperature=temperature,
|
218 |
+
score_corrector=score_corrector,
|
219 |
+
corrector_kwargs=corrector_kwargs,
|
220 |
+
x_T=x_T,
|
221 |
+
log_every_t=log_every_t,
|
222 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
223 |
+
unconditional_conditioning=unconditional_conditioning,
|
224 |
+
)
|
225 |
+
return samples, intermediates
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def plms_sampling(
|
229 |
+
self,
|
230 |
+
cond,
|
231 |
+
shape,
|
232 |
+
x_T=None,
|
233 |
+
ddim_use_original_steps=False,
|
234 |
+
callback=None,
|
235 |
+
timesteps=None,
|
236 |
+
quantize_denoised=False,
|
237 |
+
mask=None,
|
238 |
+
x0=None,
|
239 |
+
img_callback=None,
|
240 |
+
log_every_t=100,
|
241 |
+
temperature=1.0,
|
242 |
+
noise_dropout=0.0,
|
243 |
+
score_corrector=None,
|
244 |
+
corrector_kwargs=None,
|
245 |
+
unconditional_guidance_scale=1.0,
|
246 |
+
unconditional_conditioning=None,
|
247 |
+
):
|
248 |
+
device = self.model.betas.device
|
249 |
+
b = shape[0]
|
250 |
+
if x_T is None:
|
251 |
+
img = torch.randn(shape, device=device)
|
252 |
+
else:
|
253 |
+
img = x_T
|
254 |
+
|
255 |
+
if timesteps is None:
|
256 |
+
timesteps = (
|
257 |
+
self.ddpm_num_timesteps
|
258 |
+
if ddim_use_original_steps
|
259 |
+
else self.ddim_timesteps
|
260 |
+
)
|
261 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
262 |
+
subset_end = (
|
263 |
+
int(
|
264 |
+
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
265 |
+
* self.ddim_timesteps.shape[0]
|
266 |
+
)
|
267 |
+
- 1
|
268 |
+
)
|
269 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
270 |
+
|
271 |
+
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
272 |
+
time_range = (
|
273 |
+
list(reversed(range(0, timesteps)))
|
274 |
+
if ddim_use_original_steps
|
275 |
+
else np.flip(timesteps)
|
276 |
+
)
|
277 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
278 |
+
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
279 |
+
|
280 |
+
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps)
|
281 |
+
old_eps = []
|
282 |
+
|
283 |
+
for i, step in enumerate(iterator):
|
284 |
+
index = total_steps - i - 1
|
285 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
286 |
+
ts_next = torch.full(
|
287 |
+
(b,),
|
288 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
289 |
+
device=device,
|
290 |
+
dtype=torch.long,
|
291 |
+
)
|
292 |
+
|
293 |
+
if mask is not None:
|
294 |
+
assert x0 is not None
|
295 |
+
# import ipdb; ipdb.set_trace()
|
296 |
+
img_orig = self.model.q_sample(
|
297 |
+
x0, ts
|
298 |
+
) # TODO: deterministic forward pass?
|
299 |
+
img = img_orig * mask + (1.0 - mask) * img
|
300 |
+
|
301 |
+
outs = self.p_sample_plms(
|
302 |
+
img,
|
303 |
+
cond,
|
304 |
+
ts,
|
305 |
+
index=index,
|
306 |
+
use_original_steps=ddim_use_original_steps,
|
307 |
+
quantize_denoised=quantize_denoised,
|
308 |
+
temperature=temperature,
|
309 |
+
noise_dropout=noise_dropout,
|
310 |
+
score_corrector=score_corrector,
|
311 |
+
corrector_kwargs=corrector_kwargs,
|
312 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
313 |
+
unconditional_conditioning=unconditional_conditioning,
|
314 |
+
old_eps=old_eps,
|
315 |
+
t_next=ts_next,
|
316 |
+
)
|
317 |
+
img, pred_x0, e_t = outs
|
318 |
+
old_eps.append(e_t)
|
319 |
+
if len(old_eps) >= 4:
|
320 |
+
old_eps.pop(0)
|
321 |
+
if callback:
|
322 |
+
callback(i)
|
323 |
+
if img_callback:
|
324 |
+
img_callback(pred_x0, i)
|
325 |
+
|
326 |
+
if index % 1 == 0 or index == total_steps - 1:
|
327 |
+
intermediates["x_inter"].append(img)
|
328 |
+
intermediates["pred_x0"].append(pred_x0)
|
329 |
+
|
330 |
+
return img, intermediates
|
331 |
+
|
332 |
+
@torch.no_grad()
|
333 |
+
def p_sample_plms(
|
334 |
+
self,
|
335 |
+
x,
|
336 |
+
c,
|
337 |
+
t,
|
338 |
+
index,
|
339 |
+
repeat_noise=False,
|
340 |
+
use_original_steps=False,
|
341 |
+
quantize_denoised=False,
|
342 |
+
temperature=1.0,
|
343 |
+
noise_dropout=0.0,
|
344 |
+
score_corrector=None,
|
345 |
+
corrector_kwargs=None,
|
346 |
+
unconditional_guidance_scale=1.0,
|
347 |
+
unconditional_conditioning=None,
|
348 |
+
old_eps=None,
|
349 |
+
t_next=None,
|
350 |
+
):
|
351 |
+
b, *_, device = *x.shape, x.device
|
352 |
+
|
353 |
+
def get_model_output(x, t):
|
354 |
+
if (
|
355 |
+
unconditional_conditioning is None
|
356 |
+
or unconditional_guidance_scale == 1.0
|
357 |
+
):
|
358 |
+
e_t = self.model.apply_model(x, t, c)
|
359 |
+
else:
|
360 |
+
x_in = torch.cat([x] * 2)
|
361 |
+
t_in = torch.cat([t] * 2)
|
362 |
+
if isinstance(c, dict):
|
363 |
+
c_in = {key: [torch.cat([unconditional_conditioning[key][0], c[key][0]])] for key in c}
|
364 |
+
else:
|
365 |
+
c_in = torch.cat([unconditional_conditioning, c])
|
366 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
367 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
368 |
+
|
369 |
+
if score_corrector is not None:
|
370 |
+
assert self.model.parameterization == "eps"
|
371 |
+
e_t = score_corrector.modify_score(
|
372 |
+
self.model, e_t, x, t, c, **corrector_kwargs
|
373 |
+
)
|
374 |
+
|
375 |
+
return e_t
|
376 |
+
|
377 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
378 |
+
alphas_prev = (
|
379 |
+
self.model.alphas_cumprod_prev
|
380 |
+
if use_original_steps
|
381 |
+
else self.ddim_alphas_prev
|
382 |
+
)
|
383 |
+
sqrt_one_minus_alphas = (
|
384 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
385 |
+
if use_original_steps
|
386 |
+
else self.ddim_sqrt_one_minus_alphas
|
387 |
+
)
|
388 |
+
sigmas = (
|
389 |
+
self.model.ddim_sigmas_for_original_num_steps
|
390 |
+
if use_original_steps
|
391 |
+
else self.ddim_sigmas
|
392 |
+
)
|
393 |
+
|
394 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
395 |
+
# select parameters corresponding to the currently considered timestep
|
396 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
397 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
398 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
399 |
+
sqrt_one_minus_at = torch.full(
|
400 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
401 |
+
)
|
402 |
+
|
403 |
+
# current prediction for x_0
|
404 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
405 |
+
if quantize_denoised:
|
406 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
407 |
+
# direction pointing to x_t
|
408 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
409 |
+
noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
|
410 |
+
if noise_dropout > 0.0:
|
411 |
+
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
412 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
413 |
+
return x_prev, pred_x0
|
414 |
+
|
415 |
+
e_t = get_model_output(x, t)
|
416 |
+
if len(old_eps) == 0:
|
417 |
+
# Pseudo Improved Euler (2nd order)
|
418 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
419 |
+
e_t_next = get_model_output(x_prev, t_next)
|
420 |
+
e_t_prime = (e_t + e_t_next) / 2
|
421 |
+
elif len(old_eps) == 1:
|
422 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
423 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
424 |
+
elif len(old_eps) == 2:
|
425 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
426 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
427 |
+
elif len(old_eps) >= 3:
|
428 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
429 |
+
e_t_prime = (
|
430 |
+
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
|
431 |
+
) / 24
|
432 |
+
|
433 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
434 |
+
|
435 |
+
return x_prev, pred_x0, e_t
|
436 |
+
|
437 |
+
###### Above are original stable-diffusion code ############
|
438 |
+
|
439 |
+
###### Encode Image ########################################
|
440 |
+
|
441 |
+
@torch.no_grad()
|
442 |
+
def sample_encode_save_noise(
|
443 |
+
self,
|
444 |
+
S,
|
445 |
+
batch_size,
|
446 |
+
shape,
|
447 |
+
conditioning=None,
|
448 |
+
callback=None,
|
449 |
+
normals_sequence=None,
|
450 |
+
img_callback=None,
|
451 |
+
quantize_x0=False,
|
452 |
+
eta=0.0,
|
453 |
+
mask=None,
|
454 |
+
x0=None,
|
455 |
+
temperature=1.0,
|
456 |
+
noise_dropout=0.0,
|
457 |
+
score_corrector=None,
|
458 |
+
corrector_kwargs=None,
|
459 |
+
verbose=True,
|
460 |
+
x_T=None,
|
461 |
+
log_every_t=100,
|
462 |
+
unconditional_guidance_scale=1.0,
|
463 |
+
unconditional_conditioning=None,
|
464 |
+
input_image=None,
|
465 |
+
noise_save_path=None,
|
466 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
467 |
+
**kwargs,
|
468 |
+
):
|
469 |
+
assert conditioning is not None
|
470 |
+
# assert not isinstance(conditioning, dict)
|
471 |
+
|
472 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
473 |
+
# sampling
|
474 |
+
C, H, W = shape
|
475 |
+
size = (batch_size, C, H, W)
|
476 |
+
if verbose:
|
477 |
+
print(f"Data shape for PLMS sampling is {size}")
|
478 |
+
|
479 |
+
samples, intermediates, x0_loop = self.plms_sampling_enc_save_noise(
|
480 |
+
conditioning,
|
481 |
+
size,
|
482 |
+
callback=callback,
|
483 |
+
img_callback=img_callback,
|
484 |
+
quantize_denoised=quantize_x0,
|
485 |
+
mask=mask,
|
486 |
+
x0=x0,
|
487 |
+
ddim_use_original_steps=False,
|
488 |
+
noise_dropout=noise_dropout,
|
489 |
+
temperature=temperature,
|
490 |
+
score_corrector=score_corrector,
|
491 |
+
corrector_kwargs=corrector_kwargs,
|
492 |
+
x_T=x_T,
|
493 |
+
log_every_t=log_every_t,
|
494 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
495 |
+
unconditional_conditioning=unconditional_conditioning,
|
496 |
+
input_image=input_image,
|
497 |
+
noise_save_path=noise_save_path,
|
498 |
+
verbose=verbose
|
499 |
+
)
|
500 |
+
return samples, intermediates, x0_loop
|
501 |
+
|
502 |
+
@torch.no_grad()
|
503 |
+
def plms_sampling_enc_save_noise(
|
504 |
+
self,
|
505 |
+
cond,
|
506 |
+
shape,
|
507 |
+
x_T=None,
|
508 |
+
ddim_use_original_steps=False,
|
509 |
+
callback=None,
|
510 |
+
timesteps=None,
|
511 |
+
quantize_denoised=False,
|
512 |
+
mask=None,
|
513 |
+
x0=None,
|
514 |
+
img_callback=None,
|
515 |
+
log_every_t=100,
|
516 |
+
temperature=1.0,
|
517 |
+
noise_dropout=0.0,
|
518 |
+
score_corrector=None,
|
519 |
+
corrector_kwargs=None,
|
520 |
+
unconditional_guidance_scale=1.0,
|
521 |
+
unconditional_conditioning=None,
|
522 |
+
input_image=None,
|
523 |
+
noise_save_path=None,
|
524 |
+
verbose=True,
|
525 |
+
):
|
526 |
+
device = self.model.betas.device
|
527 |
+
|
528 |
+
b = shape[0]
|
529 |
+
if x_T is None:
|
530 |
+
img = torch.randn(shape, device=device)
|
531 |
+
else:
|
532 |
+
img = x_T
|
533 |
+
|
534 |
+
if timesteps is None:
|
535 |
+
timesteps = (
|
536 |
+
self.ddpm_num_timesteps
|
537 |
+
if ddim_use_original_steps
|
538 |
+
else self.ddim_timesteps
|
539 |
+
)
|
540 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
541 |
+
subset_end = (
|
542 |
+
int(
|
543 |
+
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
544 |
+
* self.ddim_timesteps.shape[0]
|
545 |
+
)
|
546 |
+
- 1
|
547 |
+
)
|
548 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
549 |
+
|
550 |
+
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
551 |
+
time_range = (
|
552 |
+
list(reversed(range(0, timesteps)))
|
553 |
+
if ddim_use_original_steps
|
554 |
+
else np.flip(timesteps)
|
555 |
+
)
|
556 |
+
time_range = list(range(0, timesteps)) if ddim_use_original_steps else timesteps
|
557 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
558 |
+
if verbose:
|
559 |
+
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
560 |
+
iterator = tqdm(time_range[:-1], desc='PLMS Sampler', total=total_steps)
|
561 |
+
else:
|
562 |
+
iterator = time_range[:-1]
|
563 |
+
old_eps = []
|
564 |
+
noise_images = []
|
565 |
+
for each_time in time_range:
|
566 |
+
noised_image = self.model.q_sample(
|
567 |
+
input_image, torch.tensor([each_time]).to(device)
|
568 |
+
)
|
569 |
+
noise_images.append(noised_image)
|
570 |
+
# torch.save(noised_image, noise_save_path + "_image_time%d.pt" % (each_time))
|
571 |
+
# import pudb; pudb.set_trace()
|
572 |
+
x0_loop = input_image.clone()
|
573 |
+
alphas = (
|
574 |
+
self.model.alphas_cumprod if ddim_use_original_steps else self.ddim_alphas
|
575 |
+
)
|
576 |
+
alphas_prev = (
|
577 |
+
self.model.alphas_cumprod_prev
|
578 |
+
if ddim_use_original_steps
|
579 |
+
else self.ddim_alphas_prev
|
580 |
+
)
|
581 |
+
sqrt_one_minus_alphas = (
|
582 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
583 |
+
if ddim_use_original_steps
|
584 |
+
else self.ddim_sqrt_one_minus_alphas
|
585 |
+
)
|
586 |
+
sigmas = (
|
587 |
+
self.model.ddim_sigmas_for_original_num_steps
|
588 |
+
if ddim_use_original_steps
|
589 |
+
else self.ddim_sigmas
|
590 |
+
)
|
591 |
+
|
592 |
+
def get_model_output(x, t):
|
593 |
+
x_in = torch.cat([x] * 2)
|
594 |
+
t_in = torch.cat([t] * 2)
|
595 |
+
if isinstance(cond, dict):
|
596 |
+
c_in = {key: [torch.cat([unconditional_conditioning[key][0], cond[key][0]])] for key in cond}
|
597 |
+
else:
|
598 |
+
c_in = torch.cat([unconditional_conditioning, cond])
|
599 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
600 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
601 |
+
return e_t
|
602 |
+
|
603 |
+
def get_x_prev_and_pred_x0(e_t, index, curr_x0):
|
604 |
+
# select parameters corresponding to the currently considered timestep
|
605 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
606 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
607 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
608 |
+
sqrt_one_minus_at = torch.full(
|
609 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
610 |
+
)
|
611 |
+
|
612 |
+
# current prediction for x_0
|
613 |
+
pred_x0 = (curr_x0 - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
614 |
+
|
615 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index + 1], device=device)
|
616 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index + 1], device=device)
|
617 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index + 1], device=device)
|
618 |
+
sqrt_one_minus_at = torch.full(
|
619 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index + 1], device=device
|
620 |
+
)
|
621 |
+
|
622 |
+
dir_xt = (1.0 - a_t - sigma_t ** 2).sqrt() * e_t
|
623 |
+
|
624 |
+
x_prev = a_t.sqrt() * pred_x0 + dir_xt
|
625 |
+
|
626 |
+
return x_prev, pred_x0
|
627 |
+
|
628 |
+
for i, step in enumerate(iterator):
|
629 |
+
index = i
|
630 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
631 |
+
ts_next = torch.full(
|
632 |
+
(b,),
|
633 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
634 |
+
device=device,
|
635 |
+
dtype=torch.long,
|
636 |
+
)
|
637 |
+
e_t = get_model_output(x0_loop, ts)
|
638 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index, x0_loop)
|
639 |
+
x0_loop = x_prev
|
640 |
+
# torch.save(x0_loop, noise_save_path + "_final_latent.pt")
|
641 |
+
|
642 |
+
# Reconstruction
|
643 |
+
img = x0_loop.clone()
|
644 |
+
time_range = (
|
645 |
+
list(reversed(range(0, timesteps)))
|
646 |
+
if ddim_use_original_steps
|
647 |
+
else np.flip(timesteps)
|
648 |
+
)
|
649 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
650 |
+
if verbose:
|
651 |
+
print(f"Running PLMS Sampling with {total_steps} timesteps")
|
652 |
+
iterator = tqdm(time_range, desc="PLMS Sampler", total=total_steps, miniters=total_steps+1, mininterval=600)
|
653 |
+
else:
|
654 |
+
iterator = time_range
|
655 |
+
old_eps = []
|
656 |
+
for i, step in enumerate(iterator):
|
657 |
+
index = total_steps - i - 1
|
658 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
659 |
+
ts_next = torch.full(
|
660 |
+
(b,),
|
661 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
662 |
+
device=device,
|
663 |
+
dtype=torch.long,
|
664 |
+
)
|
665 |
+
|
666 |
+
if mask is not None:
|
667 |
+
assert x0 is not None
|
668 |
+
img_orig = self.model.q_sample(
|
669 |
+
x0, ts
|
670 |
+
) # TODO: deterministic forward pass?
|
671 |
+
img = img_orig * mask + (1.0 - mask) * img
|
672 |
+
|
673 |
+
outs = self.p_sample_plms_dec_save_noise(
|
674 |
+
img,
|
675 |
+
cond,
|
676 |
+
ts,
|
677 |
+
index=index,
|
678 |
+
use_original_steps=ddim_use_original_steps,
|
679 |
+
quantize_denoised=quantize_denoised,
|
680 |
+
temperature=temperature,
|
681 |
+
noise_dropout=noise_dropout,
|
682 |
+
score_corrector=score_corrector,
|
683 |
+
corrector_kwargs=corrector_kwargs,
|
684 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
685 |
+
unconditional_conditioning=unconditional_conditioning,
|
686 |
+
old_eps=old_eps,
|
687 |
+
t_next=ts_next,
|
688 |
+
input_image=input_image,
|
689 |
+
noise_save_path=noise_save_path,
|
690 |
+
noise_image=noise_images.pop(),
|
691 |
+
)
|
692 |
+
img, pred_x0, e_t = outs
|
693 |
+
|
694 |
+
old_eps.append(e_t)
|
695 |
+
if len(old_eps) >= 4:
|
696 |
+
old_eps.pop(0)
|
697 |
+
if callback:
|
698 |
+
callback(i)
|
699 |
+
if img_callback:
|
700 |
+
img_callback(pred_x0, i)
|
701 |
+
|
702 |
+
if index % log_every_t == 0 or index == total_steps - 1:
|
703 |
+
intermediates["x_inter"].append(img)
|
704 |
+
intermediates["pred_x0"].append(pred_x0)
|
705 |
+
|
706 |
+
return img, intermediates, x0_loop
|
707 |
+
|
708 |
+
@torch.no_grad()
|
709 |
+
def p_sample_plms_dec_save_noise(
|
710 |
+
self,
|
711 |
+
x,
|
712 |
+
c1,
|
713 |
+
t,
|
714 |
+
index,
|
715 |
+
repeat_noise=False,
|
716 |
+
use_original_steps=False,
|
717 |
+
quantize_denoised=False,
|
718 |
+
temperature=1.0,
|
719 |
+
noise_dropout=0.0,
|
720 |
+
score_corrector=None,
|
721 |
+
corrector_kwargs=None,
|
722 |
+
unconditional_guidance_scale=1.0,
|
723 |
+
unconditional_conditioning=None,
|
724 |
+
old_eps=None,
|
725 |
+
t_next=None,
|
726 |
+
input_image=None,
|
727 |
+
noise_save_path=None,
|
728 |
+
noise_image=None,
|
729 |
+
):
|
730 |
+
b, *_, device = *x.shape, x.device
|
731 |
+
|
732 |
+
def get_model_output(x, t):
|
733 |
+
if (
|
734 |
+
unconditional_conditioning is None
|
735 |
+
or unconditional_guidance_scale == 1.0
|
736 |
+
):
|
737 |
+
e_t = self.model.apply_model(x, t, c1)
|
738 |
+
else:
|
739 |
+
x_in = torch.cat([x] * 2)
|
740 |
+
t_in = torch.cat([t] * 2)
|
741 |
+
if isinstance(c1, dict):
|
742 |
+
c_in = {key: [torch.cat([unconditional_conditioning[key][0], c1[key][0]])] for key in c1}
|
743 |
+
else:
|
744 |
+
c_in = torch.cat([unconditional_conditioning, c1])
|
745 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
746 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
747 |
+
return e_t
|
748 |
+
|
749 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
750 |
+
alphas_prev = (
|
751 |
+
self.model.alphas_cumprod_prev
|
752 |
+
if use_original_steps
|
753 |
+
else self.ddim_alphas_prev
|
754 |
+
)
|
755 |
+
sqrt_one_minus_alphas = (
|
756 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
757 |
+
if use_original_steps
|
758 |
+
else self.ddim_sqrt_one_minus_alphas
|
759 |
+
)
|
760 |
+
sigmas = (
|
761 |
+
self.model.ddim_sigmas_for_original_num_steps
|
762 |
+
if use_original_steps
|
763 |
+
else self.ddim_sigmas
|
764 |
+
)
|
765 |
+
|
766 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
767 |
+
# select parameters corresponding to the currently considered timestep
|
768 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
769 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
770 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
771 |
+
sqrt_one_minus_at = torch.full(
|
772 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
773 |
+
)
|
774 |
+
|
775 |
+
# current prediction for x_0
|
776 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
777 |
+
if quantize_denoised:
|
778 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
779 |
+
# direction pointing to x_t
|
780 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
781 |
+
time_curr = index * 20 + 1
|
782 |
+
# img_prev = torch.load(noise_save_path + "_image_time%d.pt" % (time_curr))
|
783 |
+
img_prev = noise_image
|
784 |
+
noise = img_prev - a_prev.sqrt() * pred_x0 - dir_xt
|
785 |
+
# torch.save(noise, noise_save_path + "_time%d.pt" % (time_curr))
|
786 |
+
|
787 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
788 |
+
return x_prev, pred_x0
|
789 |
+
|
790 |
+
e_t = get_model_output(x, t)
|
791 |
+
if len(old_eps) == 0:
|
792 |
+
# Pseudo Improved Euler (2nd order)
|
793 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
794 |
+
e_t_next = get_model_output(x_prev, t_next)
|
795 |
+
e_t_prime = (e_t + e_t_next) / 2
|
796 |
+
elif len(old_eps) == 1:
|
797 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
798 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
799 |
+
elif len(old_eps) == 2:
|
800 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
801 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
802 |
+
elif len(old_eps) >= 3:
|
803 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
804 |
+
e_t_prime = (
|
805 |
+
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
|
806 |
+
) / 24
|
807 |
+
|
808 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
809 |
+
|
810 |
+
return x_prev, pred_x0, e_t
|
811 |
+
|
812 |
+
################## Encode Image End ###############################
|
813 |
+
|
814 |
+
def p_sample_plms_sampling(
|
815 |
+
self,
|
816 |
+
x,
|
817 |
+
c1,
|
818 |
+
c2,
|
819 |
+
t,
|
820 |
+
index,
|
821 |
+
repeat_noise=False,
|
822 |
+
use_original_steps=False,
|
823 |
+
quantize_denoised=False,
|
824 |
+
temperature=1.0,
|
825 |
+
noise_dropout=0.0,
|
826 |
+
score_corrector=None,
|
827 |
+
corrector_kwargs=None,
|
828 |
+
unconditional_guidance_scale=1.0,
|
829 |
+
unconditional_conditioning=None,
|
830 |
+
old_eps=None,
|
831 |
+
t_next=None,
|
832 |
+
input_image=None,
|
833 |
+
optimizing_weight=None,
|
834 |
+
noise_save_path=None,
|
835 |
+
):
|
836 |
+
b, *_, device = *x.shape, x.device
|
837 |
+
|
838 |
+
def optimize_model_output(x, t):
|
839 |
+
# weight_for_pencil = torch.nn.Sigmoid()(optimizing_weight)
|
840 |
+
# condition = weight_for_pencil * c1 + (1 - weight_for_pencil) * c2
|
841 |
+
condition = optimizing_weight * c1 + (1 - optimizing_weight) * c2
|
842 |
+
if (
|
843 |
+
unconditional_conditioning is None
|
844 |
+
or unconditional_guidance_scale == 1.0
|
845 |
+
):
|
846 |
+
e_t = self.model.apply_model(x, t, condition)
|
847 |
+
else:
|
848 |
+
x_in = torch.cat([x] * 2)
|
849 |
+
t_in = torch.cat([t] * 2)
|
850 |
+
if isinstance(condition, dict):
|
851 |
+
c_in = {key: [torch.cat([unconditional_conditioning[key][0], condition[key][0]])] for key in condition}
|
852 |
+
else:
|
853 |
+
c_in = torch.cat([unconditional_conditioning, condition])
|
854 |
+
e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
|
855 |
+
e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
|
856 |
+
return e_t
|
857 |
+
|
858 |
+
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
859 |
+
alphas_prev = (
|
860 |
+
self.model.alphas_cumprod_prev
|
861 |
+
if use_original_steps
|
862 |
+
else self.ddim_alphas_prev
|
863 |
+
)
|
864 |
+
sqrt_one_minus_alphas = (
|
865 |
+
self.model.sqrt_one_minus_alphas_cumprod
|
866 |
+
if use_original_steps
|
867 |
+
else self.ddim_sqrt_one_minus_alphas
|
868 |
+
)
|
869 |
+
sigmas = (
|
870 |
+
self.model.ddim_sigmas_for_original_num_steps
|
871 |
+
if use_original_steps
|
872 |
+
else self.ddim_sigmas
|
873 |
+
)
|
874 |
+
|
875 |
+
def get_x_prev_and_pred_x0(e_t, index):
|
876 |
+
# select parameters corresponding to the currently considered timestep
|
877 |
+
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
878 |
+
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
879 |
+
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
880 |
+
sqrt_one_minus_at = torch.full(
|
881 |
+
(b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
|
882 |
+
)
|
883 |
+
|
884 |
+
# current prediction for x_0
|
885 |
+
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
886 |
+
if quantize_denoised:
|
887 |
+
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
888 |
+
# direction pointing to x_t
|
889 |
+
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
|
890 |
+
time_curr = index * 20 + 1
|
891 |
+
if noise_save_path and index > 16:
|
892 |
+
noise = torch.load(noise_save_path + "_time%d.pt" % (time_curr))[:1]
|
893 |
+
else:
|
894 |
+
noise = torch.zeros_like(dir_xt)
|
895 |
+
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
896 |
+
return x_prev, pred_x0
|
897 |
+
|
898 |
+
e_t = optimize_model_output(x, t)
|
899 |
+
if len(old_eps) == 0:
|
900 |
+
# Pseudo Improved Euler (2nd order)
|
901 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
902 |
+
# e_t_next = get_model_output(x_prev, t_next)
|
903 |
+
e_t_next = optimize_model_output(x_prev, t_next)
|
904 |
+
e_t_prime = (e_t + e_t_next) / 2
|
905 |
+
elif len(old_eps) == 1:
|
906 |
+
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
907 |
+
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
908 |
+
elif len(old_eps) == 2:
|
909 |
+
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
910 |
+
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
911 |
+
elif len(old_eps) >= 3:
|
912 |
+
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
913 |
+
e_t_prime = (
|
914 |
+
55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]
|
915 |
+
) / 24
|
916 |
+
|
917 |
+
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
918 |
+
|
919 |
+
return x_prev, pred_x0, e_t
|
920 |
+
|
921 |
+
################## Edit Input Image ###############################
|
922 |
+
|
923 |
+
def sample_optimize_intrinsic_edit(
|
924 |
+
self,
|
925 |
+
S,
|
926 |
+
batch_size,
|
927 |
+
shape,
|
928 |
+
conditioning1=None,
|
929 |
+
conditioning2=None,
|
930 |
+
callback=None,
|
931 |
+
normals_sequence=None,
|
932 |
+
img_callback=None,
|
933 |
+
quantize_x0=False,
|
934 |
+
eta=0.0,
|
935 |
+
mask=None,
|
936 |
+
x0=None,
|
937 |
+
temperature=1.0,
|
938 |
+
noise_dropout=0.0,
|
939 |
+
score_corrector=None,
|
940 |
+
corrector_kwargs=None,
|
941 |
+
verbose=True,
|
942 |
+
x_T=None,
|
943 |
+
log_every_t=100,
|
944 |
+
unconditional_guidance_scale=1.0,
|
945 |
+
unconditional_conditioning=None,
|
946 |
+
input_image=None,
|
947 |
+
noise_save_path=None,
|
948 |
+
lambda_t=None,
|
949 |
+
lambda_save_path=None,
|
950 |
+
image_save_path=None,
|
951 |
+
original_text=None,
|
952 |
+
new_text=None,
|
953 |
+
otext=None,
|
954 |
+
noise_saved_path=None,
|
955 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
956 |
+
**kwargs,
|
957 |
+
):
|
958 |
+
assert conditioning1 is not None
|
959 |
+
assert conditioning2 is not None
|
960 |
+
|
961 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
962 |
+
# sampling
|
963 |
+
C, H, W = shape
|
964 |
+
size = (batch_size, C, H, W)
|
965 |
+
print(f"Data shape for PLMS sampling is {size}")
|
966 |
+
|
967 |
+
self.plms_sampling_optimize_intrinsic_edit(
|
968 |
+
conditioning1,
|
969 |
+
conditioning2,
|
970 |
+
size,
|
971 |
+
callback=callback,
|
972 |
+
img_callback=img_callback,
|
973 |
+
quantize_denoised=quantize_x0,
|
974 |
+
mask=mask,
|
975 |
+
x0=x0,
|
976 |
+
ddim_use_original_steps=False,
|
977 |
+
noise_dropout=noise_dropout,
|
978 |
+
temperature=temperature,
|
979 |
+
score_corrector=score_corrector,
|
980 |
+
corrector_kwargs=corrector_kwargs,
|
981 |
+
x_T=x_T,
|
982 |
+
log_every_t=log_every_t,
|
983 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
984 |
+
unconditional_conditioning=unconditional_conditioning,
|
985 |
+
input_image=input_image,
|
986 |
+
noise_save_path=noise_save_path,
|
987 |
+
lambda_t=lambda_t,
|
988 |
+
lambda_save_path=lambda_save_path,
|
989 |
+
image_save_path=image_save_path,
|
990 |
+
original_text=original_text,
|
991 |
+
new_text=new_text,
|
992 |
+
otext=otext,
|
993 |
+
noise_saved_path=noise_saved_path,
|
994 |
+
)
|
995 |
+
return None
|
996 |
+
|
997 |
+
def plms_sampling_optimize_intrinsic_edit(
|
998 |
+
self,
|
999 |
+
cond1,
|
1000 |
+
cond2,
|
1001 |
+
shape,
|
1002 |
+
x_T=None,
|
1003 |
+
ddim_use_original_steps=False,
|
1004 |
+
callback=None,
|
1005 |
+
timesteps=None,
|
1006 |
+
quantize_denoised=False,
|
1007 |
+
mask=None,
|
1008 |
+
x0=None,
|
1009 |
+
img_callback=None,
|
1010 |
+
log_every_t=100,
|
1011 |
+
temperature=1.0,
|
1012 |
+
noise_dropout=0.0,
|
1013 |
+
score_corrector=None,
|
1014 |
+
corrector_kwargs=None,
|
1015 |
+
unconditional_guidance_scale=1.0,
|
1016 |
+
unconditional_conditioning=None,
|
1017 |
+
input_image=None,
|
1018 |
+
noise_save_path=None,
|
1019 |
+
lambda_t=None,
|
1020 |
+
lambda_save_path=None,
|
1021 |
+
image_save_path=None,
|
1022 |
+
original_text=None,
|
1023 |
+
new_text=None,
|
1024 |
+
otext=None,
|
1025 |
+
noise_saved_path=None,
|
1026 |
+
):
|
1027 |
+
# Different from above, the intrinsic edit version needs
|
1028 |
+
device = self.model.betas.device
|
1029 |
+
|
1030 |
+
b = shape[0]
|
1031 |
+
if x_T is None:
|
1032 |
+
img = torch.randn(shape, device=device)
|
1033 |
+
else:
|
1034 |
+
img = x_T
|
1035 |
+
img_clone = img.clone()
|
1036 |
+
|
1037 |
+
if timesteps is None:
|
1038 |
+
timesteps = (
|
1039 |
+
self.ddpm_num_timesteps
|
1040 |
+
if ddim_use_original_steps
|
1041 |
+
else self.ddim_timesteps
|
1042 |
+
)
|
1043 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
1044 |
+
subset_end = (
|
1045 |
+
int(
|
1046 |
+
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
1047 |
+
* self.ddim_timesteps.shape[0]
|
1048 |
+
)
|
1049 |
+
- 1
|
1050 |
+
)
|
1051 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
1052 |
+
|
1053 |
+
intermediates = {"x_inter": [img], "pred_x0": [img]}
|
1054 |
+
time_range = (
|
1055 |
+
list(reversed(range(0, timesteps)))
|
1056 |
+
if ddim_use_original_steps
|
1057 |
+
else np.flip(timesteps)
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
weighting_parameter = lambda_t
|
1061 |
+
weighting_parameter.requires_grad = True
|
1062 |
+
from torch import optim
|
1063 |
+
|
1064 |
+
optimizer = optim.Adam([weighting_parameter], lr=0.05)
|
1065 |
+
|
1066 |
+
print("Original image")
|
1067 |
+
with torch.no_grad():
|
1068 |
+
img = img_clone.clone()
|
1069 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
1070 |
+
iterator = time_range
|
1071 |
+
old_eps = []
|
1072 |
+
|
1073 |
+
for i, step in enumerate(iterator):
|
1074 |
+
index = total_steps - i - 1
|
1075 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
1076 |
+
ts_next = torch.full(
|
1077 |
+
(b,),
|
1078 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
1079 |
+
device=device,
|
1080 |
+
dtype=torch.long,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
outs = self.p_sample_plms_sampling(
|
1084 |
+
img,
|
1085 |
+
cond1,
|
1086 |
+
cond2,
|
1087 |
+
ts,
|
1088 |
+
index=index,
|
1089 |
+
use_original_steps=ddim_use_original_steps,
|
1090 |
+
quantize_denoised=quantize_denoised,
|
1091 |
+
temperature=temperature,
|
1092 |
+
noise_dropout=noise_dropout,
|
1093 |
+
score_corrector=score_corrector,
|
1094 |
+
corrector_kwargs=corrector_kwargs,
|
1095 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1096 |
+
unconditional_conditioning=unconditional_conditioning,
|
1097 |
+
old_eps=old_eps,
|
1098 |
+
t_next=ts_next,
|
1099 |
+
input_image=input_image,
|
1100 |
+
optimizing_weight=torch.ones(50)[i],
|
1101 |
+
noise_save_path=noise_saved_path,
|
1102 |
+
)
|
1103 |
+
img, pred_x0, e_t = outs
|
1104 |
+
old_eps.append(e_t)
|
1105 |
+
if len(old_eps) >= 4:
|
1106 |
+
old_eps.pop(0)
|
1107 |
+
img_temp = self.model.decode_first_stage(img)
|
1108 |
+
img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
|
1109 |
+
img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)
|
1110 |
+
# save image
|
1111 |
+
with torch.no_grad():
|
1112 |
+
x_sample = 255.0 * rearrange(
|
1113 |
+
img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
|
1114 |
+
)
|
1115 |
+
imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1116 |
+
imgsave.save(image_save_path + "original.png")
|
1117 |
+
readed_image = (
|
1118 |
+
torchvision.io.read_image(image_save_path + "original.png").float()
|
1119 |
+
/ 255
|
1120 |
+
)
|
1121 |
+
print("Optimizing start")
|
1122 |
+
for epoch in tqdm(range(10)):
|
1123 |
+
img = img_clone.clone()
|
1124 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
1125 |
+
iterator = time_range
|
1126 |
+
old_eps = []
|
1127 |
+
|
1128 |
+
for i, step in enumerate(iterator):
|
1129 |
+
index = total_steps - i - 1
|
1130 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
1131 |
+
ts_next = torch.full(
|
1132 |
+
(b,),
|
1133 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
1134 |
+
device=device,
|
1135 |
+
dtype=torch.long,
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
outs = self.p_sample_plms_sampling(
|
1139 |
+
img,
|
1140 |
+
cond1,
|
1141 |
+
cond2,
|
1142 |
+
ts,
|
1143 |
+
index=index,
|
1144 |
+
use_original_steps=ddim_use_original_steps,
|
1145 |
+
quantize_denoised=quantize_denoised,
|
1146 |
+
temperature=temperature,
|
1147 |
+
noise_dropout=noise_dropout,
|
1148 |
+
score_corrector=score_corrector,
|
1149 |
+
corrector_kwargs=corrector_kwargs,
|
1150 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1151 |
+
unconditional_conditioning=unconditional_conditioning,
|
1152 |
+
old_eps=old_eps,
|
1153 |
+
t_next=ts_next,
|
1154 |
+
input_image=input_image,
|
1155 |
+
optimizing_weight=weighting_parameter[i],
|
1156 |
+
noise_save_path=noise_saved_path,
|
1157 |
+
)
|
1158 |
+
img, pred_x0, e_t = outs
|
1159 |
+
old_eps.append(e_t)
|
1160 |
+
if len(old_eps) >= 4:
|
1161 |
+
old_eps.pop(0)
|
1162 |
+
img_temp = self.model.decode_first_stage(img)
|
1163 |
+
img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
|
1164 |
+
img_temp_ddim = img_temp_ddim.cpu()
|
1165 |
+
|
1166 |
+
# save image
|
1167 |
+
# with torch.no_grad():
|
1168 |
+
# x_sample = 255.0 * rearrange(
|
1169 |
+
# img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
|
1170 |
+
# )
|
1171 |
+
# imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1172 |
+
# imgsave.save(image_save_path + "/%d.png" % (epoch))
|
1173 |
+
|
1174 |
+
loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
|
1175 |
+
loss2 = DCLIPLoss()(
|
1176 |
+
readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
|
1177 |
+
)
|
1178 |
+
loss = 0.05 * loss1 + loss2
|
1179 |
+
optimizer.zero_grad()
|
1180 |
+
loss.backward()
|
1181 |
+
optimizer.step()
|
1182 |
+
# torch.save(
|
1183 |
+
# weighting_parameter, lambda_save_path + "/weightingParam%d.pt" % (epoch)
|
1184 |
+
# )
|
1185 |
+
if epoch < 9:
|
1186 |
+
del img
|
1187 |
+
else:
|
1188 |
+
# save image
|
1189 |
+
with torch.no_grad():
|
1190 |
+
x_sample = 255.0 * rearrange(
|
1191 |
+
img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
|
1192 |
+
)
|
1193 |
+
imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1194 |
+
imgsave.save(image_save_path + "/final.png")
|
1195 |
+
torch.save(
|
1196 |
+
weighting_parameter, lambda_save_path + "/weightingParam_final.pt"
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
torch.cuda.empty_cache()
|
1200 |
+
# shutil.rmtree("noise")
|
1201 |
+
return None
|
1202 |
+
|
1203 |
+
################ Edit Image End ######################
|
1204 |
+
|
1205 |
+
################ Disentangle #########################
|
1206 |
+
|
1207 |
+
def sample_optimize_intrinsic(
|
1208 |
+
self,
|
1209 |
+
S,
|
1210 |
+
batch_size,
|
1211 |
+
shape,
|
1212 |
+
conditioning1=None,
|
1213 |
+
conditioning2=None,
|
1214 |
+
callback=None,
|
1215 |
+
normals_sequence=None,
|
1216 |
+
img_callback=None,
|
1217 |
+
quantize_x0=False,
|
1218 |
+
eta=0.0,
|
1219 |
+
mask=None,
|
1220 |
+
x0=None,
|
1221 |
+
temperature=1.0,
|
1222 |
+
noise_dropout=0.0,
|
1223 |
+
score_corrector=None,
|
1224 |
+
corrector_kwargs=None,
|
1225 |
+
verbose=True,
|
1226 |
+
x_T=None,
|
1227 |
+
log_every_t=100,
|
1228 |
+
unconditional_guidance_scale=1.0,
|
1229 |
+
unconditional_conditioning=None,
|
1230 |
+
input_image=None,
|
1231 |
+
noise_save_path=None,
|
1232 |
+
lambda_t=None,
|
1233 |
+
lambda_save_path=None,
|
1234 |
+
image_save_path=None,
|
1235 |
+
original_text=None,
|
1236 |
+
new_text=None,
|
1237 |
+
otext=None,
|
1238 |
+
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
1239 |
+
**kwargs,
|
1240 |
+
):
|
1241 |
+
assert conditioning1 is not None
|
1242 |
+
assert conditioning2 is not None
|
1243 |
+
|
1244 |
+
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
1245 |
+
# sampling
|
1246 |
+
C, H, W = shape
|
1247 |
+
size = (batch_size, C, H, W)
|
1248 |
+
print(f"Data shape for PLMS sampling is {size}")
|
1249 |
+
|
1250 |
+
self.plms_sampling_optimize_intrinsic(
|
1251 |
+
conditioning1,
|
1252 |
+
conditioning2,
|
1253 |
+
size,
|
1254 |
+
callback=callback,
|
1255 |
+
img_callback=img_callback,
|
1256 |
+
quantize_denoised=quantize_x0,
|
1257 |
+
mask=mask,
|
1258 |
+
x0=x0,
|
1259 |
+
ddim_use_original_steps=False,
|
1260 |
+
noise_dropout=noise_dropout,
|
1261 |
+
temperature=temperature,
|
1262 |
+
score_corrector=score_corrector,
|
1263 |
+
corrector_kwargs=corrector_kwargs,
|
1264 |
+
x_T=x_T,
|
1265 |
+
log_every_t=log_every_t,
|
1266 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1267 |
+
unconditional_conditioning=unconditional_conditioning,
|
1268 |
+
input_image=input_image,
|
1269 |
+
noise_save_path=noise_save_path,
|
1270 |
+
lambda_t=lambda_t,
|
1271 |
+
lambda_save_path=lambda_save_path,
|
1272 |
+
image_save_path=image_save_path,
|
1273 |
+
original_text=original_text,
|
1274 |
+
new_text=new_text,
|
1275 |
+
otext=otext,
|
1276 |
+
)
|
1277 |
+
return None
|
1278 |
+
|
1279 |
+
def plms_sampling_optimize_intrinsic(
|
1280 |
+
self,
|
1281 |
+
cond1,
|
1282 |
+
cond2,
|
1283 |
+
shape,
|
1284 |
+
x_T=None,
|
1285 |
+
ddim_use_original_steps=False,
|
1286 |
+
callback=None,
|
1287 |
+
timesteps=None,
|
1288 |
+
quantize_denoised=False,
|
1289 |
+
mask=None,
|
1290 |
+
x0=None,
|
1291 |
+
img_callback=None,
|
1292 |
+
log_every_t=100,
|
1293 |
+
temperature=1.0,
|
1294 |
+
noise_dropout=0.0,
|
1295 |
+
score_corrector=None,
|
1296 |
+
corrector_kwargs=None,
|
1297 |
+
unconditional_guidance_scale=1.0,
|
1298 |
+
unconditional_conditioning=None,
|
1299 |
+
input_image=None,
|
1300 |
+
noise_save_path=None,
|
1301 |
+
lambda_t=None,
|
1302 |
+
lambda_save_path=None,
|
1303 |
+
image_save_path=None,
|
1304 |
+
original_text=None,
|
1305 |
+
new_text=None,
|
1306 |
+
otext=None,
|
1307 |
+
):
|
1308 |
+
device = self.model.betas.device
|
1309 |
+
|
1310 |
+
b = shape[0]
|
1311 |
+
if x_T is None:
|
1312 |
+
img = torch.randn(shape, device=device)
|
1313 |
+
else:
|
1314 |
+
img = x_T
|
1315 |
+
img_clone = img.clone()
|
1316 |
+
|
1317 |
+
if timesteps is None:
|
1318 |
+
timesteps = (
|
1319 |
+
self.ddpm_num_timesteps
|
1320 |
+
if ddim_use_original_steps
|
1321 |
+
else self.ddim_timesteps
|
1322 |
+
)
|
1323 |
+
elif timesteps is not None and not ddim_use_original_steps:
|
1324 |
+
subset_end = (
|
1325 |
+
int(
|
1326 |
+
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
1327 |
+
* self.ddim_timesteps.shape[0]
|
1328 |
+
)
|
1329 |
+
- 1
|
1330 |
+
)
|
1331 |
+
timesteps = self.ddim_timesteps[:subset_end]
|
1332 |
+
|
1333 |
+
time_range = (
|
1334 |
+
list(reversed(range(0, timesteps)))
|
1335 |
+
if ddim_use_original_steps
|
1336 |
+
else np.flip(timesteps)
|
1337 |
+
)
|
1338 |
+
weighting_parameter = lambda_t
|
1339 |
+
weighting_parameter.requires_grad = True
|
1340 |
+
from torch import optim
|
1341 |
+
|
1342 |
+
optimizer = optim.Adam([weighting_parameter], lr=0.05)
|
1343 |
+
|
1344 |
+
print("Original image")
|
1345 |
+
with torch.no_grad():
|
1346 |
+
img = img_clone.clone()
|
1347 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
1348 |
+
iterator = time_range
|
1349 |
+
old_eps = []
|
1350 |
+
|
1351 |
+
for i, step in enumerate(iterator):
|
1352 |
+
index = total_steps - i - 1
|
1353 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
1354 |
+
ts_next = torch.full(
|
1355 |
+
(b,),
|
1356 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
1357 |
+
device=device,
|
1358 |
+
dtype=torch.long,
|
1359 |
+
)
|
1360 |
+
|
1361 |
+
outs = self.p_sample_plms_sampling(
|
1362 |
+
img,
|
1363 |
+
cond1,
|
1364 |
+
cond2,
|
1365 |
+
ts,
|
1366 |
+
index=index,
|
1367 |
+
use_original_steps=ddim_use_original_steps,
|
1368 |
+
quantize_denoised=quantize_denoised,
|
1369 |
+
temperature=temperature,
|
1370 |
+
noise_dropout=noise_dropout,
|
1371 |
+
score_corrector=score_corrector,
|
1372 |
+
corrector_kwargs=corrector_kwargs,
|
1373 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1374 |
+
unconditional_conditioning=unconditional_conditioning,
|
1375 |
+
old_eps=old_eps,
|
1376 |
+
t_next=ts_next,
|
1377 |
+
input_image=input_image,
|
1378 |
+
optimizing_weight=torch.ones(50)[i],
|
1379 |
+
noise_save_path=noise_save_path,
|
1380 |
+
)
|
1381 |
+
img, pred_x0, e_t = outs
|
1382 |
+
old_eps.append(e_t)
|
1383 |
+
if len(old_eps) >= 4:
|
1384 |
+
old_eps.pop(0)
|
1385 |
+
img_temp = self.model.decode_first_stage(img)
|
1386 |
+
del img
|
1387 |
+
img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
|
1388 |
+
img_temp_ddim = img_temp_ddim.cpu().permute(0, 2, 3, 1).permute(0, 3, 1, 2)
|
1389 |
+
# save image
|
1390 |
+
with torch.no_grad():
|
1391 |
+
x_sample = 255.0 * rearrange(
|
1392 |
+
img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
|
1393 |
+
)
|
1394 |
+
imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1395 |
+
imgsave.save(image_save_path + "original.png")
|
1396 |
+
|
1397 |
+
readed_image = (
|
1398 |
+
torchvision.io.read_image(image_save_path + "original.png").float()
|
1399 |
+
/ 255
|
1400 |
+
)
|
1401 |
+
|
1402 |
+
print("Optimizing start")
|
1403 |
+
for epoch in tqdm(range(10)):
|
1404 |
+
img = img_clone.clone()
|
1405 |
+
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
|
1406 |
+
iterator = time_range
|
1407 |
+
old_eps = []
|
1408 |
+
|
1409 |
+
for i, step in enumerate(iterator):
|
1410 |
+
index = total_steps - i - 1
|
1411 |
+
ts = torch.full((b,), step, device=device, dtype=torch.long)
|
1412 |
+
ts_next = torch.full(
|
1413 |
+
(b,),
|
1414 |
+
time_range[min(i + 1, len(time_range) - 1)],
|
1415 |
+
device=device,
|
1416 |
+
dtype=torch.long,
|
1417 |
+
)
|
1418 |
+
|
1419 |
+
outs = self.p_sample_plms_sampling(
|
1420 |
+
img,
|
1421 |
+
cond1,
|
1422 |
+
cond2,
|
1423 |
+
ts,
|
1424 |
+
index=index,
|
1425 |
+
use_original_steps=ddim_use_original_steps,
|
1426 |
+
quantize_denoised=quantize_denoised,
|
1427 |
+
temperature=temperature,
|
1428 |
+
noise_dropout=noise_dropout,
|
1429 |
+
score_corrector=score_corrector,
|
1430 |
+
corrector_kwargs=corrector_kwargs,
|
1431 |
+
unconditional_guidance_scale=unconditional_guidance_scale,
|
1432 |
+
unconditional_conditioning=unconditional_conditioning,
|
1433 |
+
old_eps=old_eps,
|
1434 |
+
t_next=ts_next,
|
1435 |
+
input_image=input_image,
|
1436 |
+
optimizing_weight=weighting_parameter[i],
|
1437 |
+
noise_save_path=noise_save_path,
|
1438 |
+
)
|
1439 |
+
img, _, e_t = outs
|
1440 |
+
old_eps.append(e_t)
|
1441 |
+
if len(old_eps) >= 4:
|
1442 |
+
old_eps.pop(0)
|
1443 |
+
img_temp = self.model.decode_first_stage(img)
|
1444 |
+
del img
|
1445 |
+
img_temp_ddim = torch.clamp((img_temp + 1.0) / 2.0, min=0.0, max=1.0)
|
1446 |
+
img_temp_ddim = img_temp_ddim.cpu()
|
1447 |
+
|
1448 |
+
# # save image
|
1449 |
+
# with torch.no_grad():
|
1450 |
+
# x_sample = 255. * rearrange(img_temp_ddim[0].detach().cpu().numpy(), 'c h w -> h w c')
|
1451 |
+
# imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1452 |
+
# imgsave.save(image_save_path + "/%d.png"%(epoch))
|
1453 |
+
|
1454 |
+
loss1 = VGGPerceptualLoss()(img_temp_ddim[0], readed_image)
|
1455 |
+
loss2 = DCLIPLoss()(
|
1456 |
+
readed_image, img_temp_ddim[0].float().cuda(), otext, new_text
|
1457 |
+
)
|
1458 |
+
loss = (
|
1459 |
+
0.05 * loss1 + loss2
|
1460 |
+
) # 0.05 or 0.03. Adjust according to attributes on scenes or people.
|
1461 |
+
optimizer.zero_grad()
|
1462 |
+
loss.backward()
|
1463 |
+
optimizer.step()
|
1464 |
+
# torch.save(weighting_parameter, lambda_save_path+"/weightingParam%d.pt"%(epoch))
|
1465 |
+
with torch.no_grad():
|
1466 |
+
if epoch == 9:
|
1467 |
+
# save image
|
1468 |
+
x_sample = 255.0 * rearrange(
|
1469 |
+
img_temp_ddim[0].detach().cpu().numpy(), "c h w -> h w c"
|
1470 |
+
)
|
1471 |
+
imgsave = Image.fromarray(x_sample.astype(np.uint8))
|
1472 |
+
imgsave.save(image_save_path + "/final.png")
|
1473 |
+
torch.save(
|
1474 |
+
weighting_parameter,
|
1475 |
+
lambda_save_path + "/weightingParam_final.pt",
|
1476 |
+
)
|
1477 |
+
torch.cuda.empty_cache()
|
1478 |
+
return None
|
1479 |
+
|
1480 |
+
|
1481 |
+
################ Disentangle End #########################
|
cldm/tmp.py
ADDED
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
use kornia and albumentations for transformations
|
5 |
+
@author: Tu Bui @University of Surrey
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
from . import utils
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
from torch import nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from PIL import Image
|
14 |
+
import kornia as ko
|
15 |
+
import albumentations as ab
|
16 |
+
|
17 |
+
|
18 |
+
class IdentityAugment(nn.Module):
|
19 |
+
def __init__(self):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
def forward(self, x, **kwargs):
|
23 |
+
return x
|
24 |
+
|
25 |
+
|
26 |
+
class RandomCompress(nn.Module):
|
27 |
+
def __init__(self, severity='medium', p=0.5):
|
28 |
+
super().__init__()
|
29 |
+
self.p = p
|
30 |
+
if severity == 'low':
|
31 |
+
self.jpeg_quality = 70
|
32 |
+
elif severity == 'medium':
|
33 |
+
self.jpeg_quality = 50
|
34 |
+
elif severity == 'high':
|
35 |
+
self.jpeg_quality = 40
|
36 |
+
|
37 |
+
def forward(self, x, ramp=1.):
|
38 |
+
# x (B, C, H, W) in range [0, 1]
|
39 |
+
# ramp: adjust the ramping of the compression, 1.0 means min quality = self.jpeg_quality
|
40 |
+
if torch.rand(1)[0] >= self.p:
|
41 |
+
return x
|
42 |
+
jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
|
43 |
+
x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
class RandomBoxBlur(nn.Module):
|
48 |
+
def __init__(self, severity='medium', border_type='reflect', normalize=True, p=0.5):
|
49 |
+
super().__init__()
|
50 |
+
self.p = p
|
51 |
+
if severity == 'low':
|
52 |
+
kernel_size = 3
|
53 |
+
elif severity == 'medium':
|
54 |
+
kernel_size = 5
|
55 |
+
elif severity == 'high':
|
56 |
+
kernel_size = 7
|
57 |
+
|
58 |
+
self.tform = ko.augmentation.RandomBoxBlur(kernel_size=(kernel_size, kernel_size), border_type=border_type, normalize=normalize, p=self.p)
|
59 |
+
|
60 |
+
def forward(self, x, **kwargs):
|
61 |
+
return self.tform(x)
|
62 |
+
|
63 |
+
class RandomMedianBlur(nn.Module):
|
64 |
+
def __init__(self, severity='medium', p=0.5):
|
65 |
+
super().__init__()
|
66 |
+
self.p = p
|
67 |
+
self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3,3), p=p)
|
68 |
+
|
69 |
+
def forward(self, x, **kwargs):
|
70 |
+
return self.tform(x)
|
71 |
+
|
72 |
+
|
73 |
+
class RandomBrightness(nn.Module):
|
74 |
+
def __init__(self, severity='medium', p=0.5):
|
75 |
+
super().__init__()
|
76 |
+
self.p = p
|
77 |
+
if severity == 'low':
|
78 |
+
brightness = (0.9, 1.1)
|
79 |
+
elif severity == 'medium':
|
80 |
+
brightness = (0.75, 1.25)
|
81 |
+
elif severity == 'high':
|
82 |
+
brightness = (0.5, 1.5)
|
83 |
+
self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)
|
84 |
+
|
85 |
+
def forward(self, x, **kwargs):
|
86 |
+
return self.tform(x)
|
87 |
+
|
88 |
+
|
89 |
+
class RandomContrast(nn.Module):
|
90 |
+
def __init__(self, severity='medium', p=0.5):
|
91 |
+
super().__init__()
|
92 |
+
self.p = p
|
93 |
+
if severity == 'low':
|
94 |
+
contrast = (0.9, 1.1)
|
95 |
+
elif severity == 'medium':
|
96 |
+
contrast = (0.75, 1.25)
|
97 |
+
elif severity == 'high':
|
98 |
+
contrast = (0.5, 1.5)
|
99 |
+
self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)
|
100 |
+
|
101 |
+
def forward(self, x, **kwargs):
|
102 |
+
return self.tform(x)
|
103 |
+
|
104 |
+
|
105 |
+
class RandomSaturation(nn.Module):
|
106 |
+
def __init__(self, severity='medium', p=0.5):
|
107 |
+
super().__init__()
|
108 |
+
self.p = p
|
109 |
+
if severity == 'low':
|
110 |
+
sat = (0.9, 1.1)
|
111 |
+
elif severity == 'medium':
|
112 |
+
sat = (0.75, 1.25)
|
113 |
+
elif severity == 'high':
|
114 |
+
sat = (0.5, 1.5)
|
115 |
+
self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)
|
116 |
+
|
117 |
+
def forward(self, x, **kwargs):
|
118 |
+
return self.tform(x)
|
119 |
+
|
120 |
+
class RandomSharpness(nn.Module):
|
121 |
+
def __init__(self, severity='medium', p=0.5):
|
122 |
+
super().__init__()
|
123 |
+
self.p = p
|
124 |
+
if severity == 'low':
|
125 |
+
sharpness = 0.5
|
126 |
+
elif severity == 'medium':
|
127 |
+
sharpness = 1.0
|
128 |
+
elif severity == 'high':
|
129 |
+
sharpness = 2.5
|
130 |
+
self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)
|
131 |
+
|
132 |
+
def forward(self, x, **kwargs):
|
133 |
+
return self.tform(x)
|
134 |
+
|
135 |
+
class RandomColorJiggle(nn.Module):
|
136 |
+
def __init__(self, severity='medium', p=0.5):
|
137 |
+
super().__init__()
|
138 |
+
self.p = p
|
139 |
+
if severity == 'low':
|
140 |
+
factor = (0.05, 0.05, 0.05, 0.01)
|
141 |
+
elif severity == 'medium':
|
142 |
+
factor = (0.1, 0.1, 0.1, 0.02)
|
143 |
+
elif severity == 'high':
|
144 |
+
factor = (0.1, 0.1, 0.1, 0.05)
|
145 |
+
self.tform = ko.augmentation.ColorJiggle(*factor, p=p)
|
146 |
+
|
147 |
+
def forward(self, x, **kwargs):
|
148 |
+
return self.tform(x)
|
149 |
+
|
150 |
+
class RandomHue(nn.Module):
|
151 |
+
def __init__(self, severity='medium', p=0.5):
|
152 |
+
super().__init__()
|
153 |
+
self.p = p
|
154 |
+
if severity == 'low':
|
155 |
+
hue = 0.01
|
156 |
+
elif severity == 'medium':
|
157 |
+
hue = 0.02
|
158 |
+
elif severity == 'high':
|
159 |
+
hue = 0.05
|
160 |
+
self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)
|
161 |
+
|
162 |
+
def forward(self, x, **kwargs):
|
163 |
+
return self.tform(x)
|
164 |
+
|
165 |
+
class RandomGamma(nn.Module):
|
166 |
+
def __init__(self, severity='medium', p=0.5):
|
167 |
+
super().__init__()
|
168 |
+
self.p = p
|
169 |
+
if severity == 'low':
|
170 |
+
gamma, gain = (0.9, 1.1), (0.9,1.1)
|
171 |
+
elif severity == 'medium':
|
172 |
+
gamma, gain = (0.75, 1.25), (0.75,1.25)
|
173 |
+
elif severity == 'high':
|
174 |
+
gamma, gain = (0.5, 1.5), (0.5,1.5)
|
175 |
+
self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)
|
176 |
+
|
177 |
+
def forward(self, x, **kwargs):
|
178 |
+
return self.tform(x)
|
179 |
+
|
180 |
+
class RandomGaussianBlur(nn.Module):
|
181 |
+
def __init__(self, severity='medium', p=0.5):
|
182 |
+
super().__init__()
|
183 |
+
self.p = p
|
184 |
+
if severity == 'low':
|
185 |
+
kernel_size, sigma = 3, (0.1, 1.0)
|
186 |
+
elif severity == 'medium':
|
187 |
+
kernel_size, sigma = 5, (0.1, 1.5)
|
188 |
+
elif severity == 'high':
|
189 |
+
kernel_size, sigma = 7, (0.1, 2.0)
|
190 |
+
self.tform = ko.augmentation.RandomGaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)
|
191 |
+
|
192 |
+
def forward(self, x, **kwargs):
|
193 |
+
return self.tform(x)
|
194 |
+
|
195 |
+
class RandomGaussianNoise(nn.Module):
|
196 |
+
def __init__(self, severity='medium', p=0.5):
|
197 |
+
super().__init__()
|
198 |
+
self.p = p
|
199 |
+
if severity == 'low':
|
200 |
+
std = 0.02
|
201 |
+
elif severity == 'medium':
|
202 |
+
std = 0.04
|
203 |
+
elif severity == 'high':
|
204 |
+
std = 0.08
|
205 |
+
self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)
|
206 |
+
|
207 |
+
def forward(self, x, **kwargs):
|
208 |
+
return self.tform(x)
|
209 |
+
|
210 |
+
class RandomMotionBlur(nn.Module):
|
211 |
+
def __init__(self, severity='medium', p=0.5):
|
212 |
+
super().__init__()
|
213 |
+
self.p = p
|
214 |
+
if severity == 'low':
|
215 |
+
kernel_size, angle, direction = (3, 5), (-25, 25), (-0.25, 0.25)
|
216 |
+
elif severity == 'medium':
|
217 |
+
kernel_size, angle, direction = (3, 7), (-45, 45), (-0.5, 0.5)
|
218 |
+
elif severity == 'high':
|
219 |
+
kernel_size, angle, direction = (3, 9), (-90, 90), (-1.0, 1.0)
|
220 |
+
self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)
|
221 |
+
|
222 |
+
def forward(self, x, **kwargs):
|
223 |
+
return self.tform(x)
|
224 |
+
|
225 |
+
class RandomPosterize(nn.Module):
|
226 |
+
def __init__(self, severity='medium', p=0.5):
|
227 |
+
super().__init__()
|
228 |
+
self.p = p
|
229 |
+
if severity == 'low':
|
230 |
+
bits = 5
|
231 |
+
elif severity == 'medium':
|
232 |
+
bits = 4
|
233 |
+
elif severity == 'high':
|
234 |
+
bits = 3
|
235 |
+
self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)
|
236 |
+
|
237 |
+
def forward(self, x, **kwargs):
|
238 |
+
return self.tform(x)
|
239 |
+
|
240 |
+
class RandomRGBShift(nn.Module):
|
241 |
+
def __init__(self, severity='medium', p=0.5):
|
242 |
+
super().__init__()
|
243 |
+
self.p = p
|
244 |
+
if severity == 'low':
|
245 |
+
rgb = 0.02
|
246 |
+
elif severity == 'medium':
|
247 |
+
rgb = 0.05
|
248 |
+
elif severity == 'high':
|
249 |
+
rgb = 0.1
|
250 |
+
self.tform = ko.augmentation.RandomRGBShift(r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)
|
251 |
+
|
252 |
+
def forward(self, x, **kwargs):
|
253 |
+
return self.tform(x)
|
254 |
+
|
255 |
+
|
256 |
+
|
257 |
+
class TransformNet(nn.Module):
|
258 |
+
def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True, contrast=True, color_jiggle=True, gamma=True, grayscale=True, gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True, posterize=True, rgb_shift=True, saturation=True, sharpness=True, median_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
|
259 |
+
super().__init__()
|
260 |
+
self.n_optional = n_optional
|
261 |
+
self.p = p
|
262 |
+
p_flip = 0.5 if flip else 0
|
263 |
+
rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
|
264 |
+
self.ramp = ramp
|
265 |
+
self.register_buffer('step0', torch.tensor(0))
|
266 |
+
|
267 |
+
assert crop_mode in ['random_crop', 'resized_crop']
|
268 |
+
if crop_mode == 'random_crop':
|
269 |
+
rnd_crop_layer = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
|
270 |
+
elif crop_mode == 'resized_crop':
|
271 |
+
rnd_crop_layer = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
|
272 |
+
|
273 |
+
self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]
|
274 |
+
self.optional_transforms = []
|
275 |
+
if compress:
|
276 |
+
self.optional_transforms.append(RandomCompress(severity, p=p))
|
277 |
+
if brightness:
|
278 |
+
self.optional_transforms.append(RandomBrightness(severity, p=p))
|
279 |
+
if contrast:
|
280 |
+
self.optional_transforms.append(RandomContrast(severity, p=p))
|
281 |
+
if color_jiggle:
|
282 |
+
self.optional_transforms.append(RandomColorJiggle(severity, p=p))
|
283 |
+
if gamma:
|
284 |
+
self.optional_transforms.append(RandomGamma(severity, p=p))
|
285 |
+
if grayscale:
|
286 |
+
self.optional_transforms.append(ko.augmentation.RandomGrayscale(p=p/4))
|
287 |
+
if gaussian_blur:
|
288 |
+
self.optional_transforms.append(RandomGaussianBlur(severity, p=p))
|
289 |
+
if gaussian_noise:
|
290 |
+
self.optional_transforms.append(RandomGaussianNoise(severity, p=p))
|
291 |
+
if hue:
|
292 |
+
self.optional_transforms.append(RandomHue(severity, p=p))
|
293 |
+
if motion_blur:
|
294 |
+
self.optional_transforms.append(RandomMotionBlur(severity, p=p))
|
295 |
+
if posterize:
|
296 |
+
self.optional_transforms.append(RandomPosterize(severity, p=p))
|
297 |
+
if rgb_shift:
|
298 |
+
self.optional_transforms.append(RandomRGBShift(severity, p=p))
|
299 |
+
if saturation:
|
300 |
+
self.optional_transforms.append(RandomSaturation(severity, p=p))
|
301 |
+
if sharpness:
|
302 |
+
self.optional_transforms.append(RandomSharpness(severity, p=p))
|
303 |
+
if median_blur:
|
304 |
+
self.optional_transforms.append(RandomMedianBlur(severity, p=p))
|
305 |
+
|
306 |
+
def activate(self, global_step):
|
307 |
+
if self.step0 == 0:
|
308 |
+
print(f'[TRAINING] Activating TransformNet at step {global_step}')
|
309 |
+
self.step0 = torch.tensor(global_step)
|
310 |
+
|
311 |
+
def is_activated(self):
|
312 |
+
return self.step0 > 0
|
313 |
+
|
314 |
+
def forward(self, x, global_step, p=0.9):
|
315 |
+
# x: [batch_size, 3, H, W] in range [-1, 1]
|
316 |
+
x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
|
317 |
+
# fixed transforms
|
318 |
+
for tform in self.fixed_transforms:
|
319 |
+
x = tform(x)
|
320 |
+
if isinstance(x, tuple):
|
321 |
+
x = x[0]
|
322 |
+
|
323 |
+
# optional transforms
|
324 |
+
ramp = np.min([(global_step-self.step0.cpu().item()) / self.ramp, 1.])
|
325 |
+
try:
|
326 |
+
if len(self.optional_transforms) > 0:
|
327 |
+
tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).numpy()
|
328 |
+
for tform_id in tform_ids:
|
329 |
+
tform = self.optional_transforms[tform_id]
|
330 |
+
x = tform(x, ramp=ramp)
|
331 |
+
if isinstance(x, tuple):
|
332 |
+
x = x[0]
|
333 |
+
except Exception as e:
|
334 |
+
print(tform_id, ramp)
|
335 |
+
import pdb; pdb.set_trace()
|
336 |
+
return x * 2 - 1 # [0, 1] -> [-1, 1]
|
337 |
+
|
338 |
+
|
339 |
+
if __name__ == '__main__':
|
340 |
+
pass
|
cldm/transformations.py
ADDED
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from . import utils
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from torch import nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from tools.augment_imagenetc import RandomImagenetC
|
8 |
+
from PIL import Image
|
9 |
+
import kornia as ko
|
10 |
+
# from kornia.augmentation import RandomHorizontalFlip, RandomCrop
|
11 |
+
|
12 |
+
|
13 |
+
class TransformNet(nn.Module):
|
14 |
+
def __init__(self, rnd_bri=0.3, rnd_hue=0.1, do_jpeg=False, jpeg_quality=50, rnd_noise=0.02, rnd_sat=1.0, rnd_trans=0.1,contrast=[0.5, 1.5], rnd_flip=False, ramp=1000, imagenetc_level=0, crop_mode='crop') -> None:
|
15 |
+
super().__init__()
|
16 |
+
self.rnd_bri = rnd_bri
|
17 |
+
self.rnd_hue = rnd_hue
|
18 |
+
self.jpeg_quality = jpeg_quality
|
19 |
+
self.rnd_noise = rnd_noise
|
20 |
+
self.rnd_sat = rnd_sat
|
21 |
+
self.rnd_trans = rnd_trans
|
22 |
+
self.contrast_low, self.contrast_high = contrast
|
23 |
+
self.do_jpeg = do_jpeg
|
24 |
+
p_flip = 0.5 if rnd_flip else 0
|
25 |
+
self.rnd_flip = ko.augmentation.RandomHorizontalFlip(p_flip)
|
26 |
+
self.ramp = ramp
|
27 |
+
self.register_buffer('step0', torch.tensor(0)) # large number
|
28 |
+
assert crop_mode in ['crop', 'resized_crop']
|
29 |
+
if crop_mode == 'crop':
|
30 |
+
self.rnd_crop = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
|
31 |
+
elif crop_mode == 'resized_crop':
|
32 |
+
self.rnd_crop = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
|
33 |
+
if imagenetc_level > 0:
|
34 |
+
self.imagenetc = ImagenetCTransform(max_severity=imagenetc_level)
|
35 |
+
|
36 |
+
def activate(self, global_step):
|
37 |
+
if self.step0 == 0:
|
38 |
+
print(f'[TRAINING] Activating TransformNet at step {global_step}')
|
39 |
+
self.step0 = torch.tensor(global_step)
|
40 |
+
|
41 |
+
def is_activated(self):
|
42 |
+
return self.step0 > 0
|
43 |
+
|
44 |
+
def forward(self, x, global_step, p=0.9):
|
45 |
+
# x: [batch_size, 3, H, W] in range [-1, 1]
|
46 |
+
x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
|
47 |
+
|
48 |
+
# flip
|
49 |
+
x = self.rnd_flip(x)
|
50 |
+
# random crop
|
51 |
+
x = self.rnd_crop(x)
|
52 |
+
if isinstance(x, tuple):
|
53 |
+
x = x[0] # weird bug in kornia 0.6.0 that returns transform matrix occasionally
|
54 |
+
|
55 |
+
if torch.rand(1)[0] >= p:
|
56 |
+
return x * 2 - 1 # [0, 1] -> [-1, 1]
|
57 |
+
if hasattr(self, 'imagenetc') and torch.rand(1)[0] < 0.5:
|
58 |
+
x = self.imagenetc(x * 2 - 1) # [0, 1] -> [-1, 1])
|
59 |
+
return x
|
60 |
+
|
61 |
+
batch_size, sh, device = x.shape[0], x.size(), x.device
|
62 |
+
# x0 = x.clone().detach()
|
63 |
+
ramp_fn = lambda ramp: np.min([(global_step-self.step0.cpu().item()) / ramp, 1.])
|
64 |
+
|
65 |
+
rnd_bri = ramp_fn(self.ramp) * self.rnd_bri
|
66 |
+
rnd_hue = ramp_fn(self.ramp) * self.rnd_hue
|
67 |
+
rnd_brightness = utils.get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size).to(device) # [batch_size, 3, 1, 1]
|
68 |
+
rnd_noise = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_noise
|
69 |
+
|
70 |
+
contrast_low = 1. - (1. - self.contrast_low) * ramp_fn(self.ramp)
|
71 |
+
contrast_high = 1. + (self.contrast_high - 1.) * ramp_fn(self.ramp)
|
72 |
+
contrast_params = [contrast_low, contrast_high]
|
73 |
+
|
74 |
+
# blur
|
75 |
+
N_blur = 7
|
76 |
+
f = utils.random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.],
|
77 |
+
wmin_line=3).to(device)
|
78 |
+
x = F.conv2d(x, f, bias=None, padding=int((N_blur - 1) / 2))
|
79 |
+
|
80 |
+
# noise
|
81 |
+
noise = torch.normal(mean=0, std=rnd_noise, size=x.size(), dtype=torch.float32).to(device)
|
82 |
+
x = x + noise
|
83 |
+
x = torch.clamp(x, 0, 1)
|
84 |
+
|
85 |
+
# contrast & brightness
|
86 |
+
contrast_scale = torch.Tensor(x.size()[0]).uniform_(contrast_params[0], contrast_params[1])
|
87 |
+
contrast_scale = contrast_scale.reshape(x.size()[0], 1, 1, 1).to(device)
|
88 |
+
x = x * contrast_scale
|
89 |
+
x = x + rnd_brightness
|
90 |
+
x = torch.clamp(x, 0, 1)
|
91 |
+
|
92 |
+
# saturation
|
93 |
+
# rnd_sat = torch.rand(1)[0] * ramp_fn(self.ramp) * self.rnd_sat
|
94 |
+
# sat_weight = torch.FloatTensor([.3, .6, .1]).reshape(1, 3, 1, 1).to(device)
|
95 |
+
# encoded_image_lum = torch.mean(x * sat_weight, dim=1).unsqueeze_(1)
|
96 |
+
# x = (1 - rnd_sat) * x + rnd_sat * encoded_image_lum
|
97 |
+
rnd_sat = (torch.rand(1)[0]*2.0 - 1.0)*ramp_fn(self.ramp) * self.rnd_sat + 1.0
|
98 |
+
x = ko.enhance.adjust.adjust_saturation(x, rnd_sat)
|
99 |
+
|
100 |
+
# jpeg
|
101 |
+
x = x.reshape(sh)
|
102 |
+
if self.do_jpeg:
|
103 |
+
jpeg_quality = 100. - torch.rand(1)[0] * ramp_fn(self.ramp) * (100. - self.jpeg_quality)
|
104 |
+
x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
|
105 |
+
|
106 |
+
x = x * 2 - 1 # [0, 1] -> [-1, 1]
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
class ImagenetCTransform(nn.Module):
|
111 |
+
def __init__(self, max_severity=5) -> None:
|
112 |
+
super().__init__()
|
113 |
+
self.max_severity = max_severity
|
114 |
+
self.tform = RandomImagenetC(max_severity=max_severity, phase='train')
|
115 |
+
|
116 |
+
def forward(self, x):
|
117 |
+
# x: [batch_size, 3, H, W] in range [-1, 1]
|
118 |
+
img0 = x.detach().cpu().numpy()
|
119 |
+
img = img0 * 127.5 + 127.5 # [-1, 1] -> [0, 255]
|
120 |
+
img = img.transpose(0, 2, 3, 1).astype(np.uint8)
|
121 |
+
img = [Image.fromarray(i) for i in img]
|
122 |
+
img = [self.tform(i) for i in img]
|
123 |
+
img = np.array([np.array(i) for i in img], dtype=np.float32)
|
124 |
+
img = img.transpose(0, 3, 1, 2) / 127.5 - 1. # [0, 255] -> [-1, 1]
|
125 |
+
residual = torch.from_numpy(img - img0).to(x.device)
|
126 |
+
x = x + residual
|
127 |
+
return x
|
cldm/transformations2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
use kornia and albumentations for transformations
|
5 |
+
@author: Tu Bui @University of Surrey
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
from . import utils
|
9 |
+
import torch
|
10 |
+
import numpy as np
|
11 |
+
from torch import nn
|
12 |
+
import torch.nn.functional as thf
|
13 |
+
from PIL import Image
|
14 |
+
import kornia as ko
|
15 |
+
import albumentations as ab
|
16 |
+
from torchvision import transforms
|
17 |
+
|
18 |
+
|
19 |
+
class IdentityAugment(nn.Module):
|
20 |
+
def __init__(self):
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
def forward(self, x, **kwargs):
|
24 |
+
return x
|
25 |
+
|
26 |
+
|
27 |
+
class RandomCompress(nn.Module):
|
28 |
+
def __init__(self, severity='medium', p=0.5):
|
29 |
+
super().__init__()
|
30 |
+
self.p = p
|
31 |
+
if severity == 'low':
|
32 |
+
self.jpeg_quality = 70
|
33 |
+
elif severity == 'medium':
|
34 |
+
self.jpeg_quality = 50
|
35 |
+
elif severity == 'high':
|
36 |
+
self.jpeg_quality = 40
|
37 |
+
|
38 |
+
def forward(self, x, ramp=1.):
|
39 |
+
# x (B, C, H, W) in range [0, 1]
|
40 |
+
# ramp: adjust the ramping of the compression, 1.0 means min quality = self.jpeg_quality
|
41 |
+
if torch.rand(1)[0] >= self.p:
|
42 |
+
return x
|
43 |
+
jpeg_quality = 100. - torch.rand(1)[0] * ramp * (100. - self.jpeg_quality)
|
44 |
+
x = utils.jpeg_compress_decompress(x, rounding=utils.round_only_at_0, quality=jpeg_quality)
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
class RandomBoxBlur(nn.Module):
|
49 |
+
def __init__(self, severity='medium', border_type='reflect', normalized=True, p=0.5):
|
50 |
+
super().__init__()
|
51 |
+
self.p = p
|
52 |
+
if severity == 'low':
|
53 |
+
kernel_size = 3
|
54 |
+
elif severity == 'medium':
|
55 |
+
kernel_size = 5
|
56 |
+
elif severity == 'high':
|
57 |
+
kernel_size = 7
|
58 |
+
|
59 |
+
self.tform = ko.augmentation.RandomBoxBlur(kernel_size=(kernel_size, kernel_size), border_type=border_type, normalized=normalized, p=self.p)
|
60 |
+
|
61 |
+
def forward(self, x, **kwargs):
|
62 |
+
return self.tform(x)
|
63 |
+
|
64 |
+
class RandomMedianBlur(nn.Module):
|
65 |
+
def __init__(self, severity='medium', p=0.5):
|
66 |
+
super().__init__()
|
67 |
+
self.p = p
|
68 |
+
self.tform = ko.augmentation.RandomMedianBlur(kernel_size=(3,3), p=p)
|
69 |
+
|
70 |
+
def forward(self, x, **kwargs):
|
71 |
+
return self.tform(x)
|
72 |
+
|
73 |
+
|
74 |
+
class RandomBrightness(nn.Module):
|
75 |
+
def __init__(self, severity='medium', p=0.5):
|
76 |
+
super().__init__()
|
77 |
+
self.p = p
|
78 |
+
if severity == 'low':
|
79 |
+
brightness = (0.9, 1.1)
|
80 |
+
elif severity == 'medium':
|
81 |
+
brightness = (0.75, 1.25)
|
82 |
+
elif severity == 'high':
|
83 |
+
brightness = (0.5, 1.5)
|
84 |
+
self.tform = ko.augmentation.RandomBrightness(brightness=brightness, p=p)
|
85 |
+
|
86 |
+
def forward(self, x, **kwargs):
|
87 |
+
return self.tform(x)
|
88 |
+
|
89 |
+
|
90 |
+
class RandomContrast(nn.Module):
|
91 |
+
def __init__(self, severity='medium', p=0.5):
|
92 |
+
super().__init__()
|
93 |
+
self.p = p
|
94 |
+
if severity == 'low':
|
95 |
+
contrast = (0.9, 1.1)
|
96 |
+
elif severity == 'medium':
|
97 |
+
contrast = (0.75, 1.25)
|
98 |
+
elif severity == 'high':
|
99 |
+
contrast = (0.5, 1.5)
|
100 |
+
self.tform = ko.augmentation.RandomContrast(contrast=contrast, p=p)
|
101 |
+
|
102 |
+
def forward(self, x, **kwargs):
|
103 |
+
return self.tform(x)
|
104 |
+
|
105 |
+
|
106 |
+
class RandomSaturation(nn.Module):
|
107 |
+
def __init__(self, severity='medium', p=0.5):
|
108 |
+
super().__init__()
|
109 |
+
self.p = p
|
110 |
+
if severity == 'low':
|
111 |
+
sat = (0.9, 1.1)
|
112 |
+
elif severity == 'medium':
|
113 |
+
sat = (0.75, 1.25)
|
114 |
+
elif severity == 'high':
|
115 |
+
sat = (0.5, 1.5)
|
116 |
+
self.tform = ko.augmentation.RandomSaturation(saturation=sat, p=p)
|
117 |
+
|
118 |
+
def forward(self, x, **kwargs):
|
119 |
+
return self.tform(x)
|
120 |
+
|
121 |
+
class RandomSharpness(nn.Module):
|
122 |
+
def __init__(self, severity='medium', p=0.5):
|
123 |
+
super().__init__()
|
124 |
+
self.p = p
|
125 |
+
if severity == 'low':
|
126 |
+
sharpness = 0.5
|
127 |
+
elif severity == 'medium':
|
128 |
+
sharpness = 1.0
|
129 |
+
elif severity == 'high':
|
130 |
+
sharpness = 2.5
|
131 |
+
self.tform = ko.augmentation.RandomSharpness(sharpness=sharpness, p=p)
|
132 |
+
|
133 |
+
def forward(self, x, **kwargs):
|
134 |
+
return self.tform(x)
|
135 |
+
|
136 |
+
class RandomColorJiggle(nn.Module):
|
137 |
+
def __init__(self, severity='medium', p=0.5):
|
138 |
+
super().__init__()
|
139 |
+
self.p = p
|
140 |
+
if severity == 'low':
|
141 |
+
factor = (0.05, 0.05, 0.05, 0.01)
|
142 |
+
elif severity == 'medium':
|
143 |
+
factor = (0.1, 0.1, 0.1, 0.02)
|
144 |
+
elif severity == 'high':
|
145 |
+
factor = (0.1, 0.1, 0.1, 0.05)
|
146 |
+
self.tform = ko.augmentation.ColorJiggle(*factor, p=p)
|
147 |
+
|
148 |
+
def forward(self, x, **kwargs):
|
149 |
+
return self.tform(x)
|
150 |
+
|
151 |
+
class RandomHue(nn.Module):
|
152 |
+
def __init__(self, severity='medium', p=0.5):
|
153 |
+
super().__init__()
|
154 |
+
self.p = p
|
155 |
+
if severity == 'low':
|
156 |
+
hue = 0.01
|
157 |
+
elif severity == 'medium':
|
158 |
+
hue = 0.02
|
159 |
+
elif severity == 'high':
|
160 |
+
hue = 0.05
|
161 |
+
self.tform = ko.augmentation.RandomHue(hue=(-hue, hue), p=p)
|
162 |
+
|
163 |
+
def forward(self, x, **kwargs):
|
164 |
+
return self.tform(x)
|
165 |
+
|
166 |
+
class RandomGamma(nn.Module):
|
167 |
+
def __init__(self, severity='medium', p=0.5):
|
168 |
+
super().__init__()
|
169 |
+
self.p = p
|
170 |
+
if severity == 'low':
|
171 |
+
gamma, gain = (0.9, 1.1), (0.9,1.1)
|
172 |
+
elif severity == 'medium':
|
173 |
+
gamma, gain = (0.75, 1.25), (0.75,1.25)
|
174 |
+
elif severity == 'high':
|
175 |
+
gamma, gain = (0.5, 1.5), (0.5,1.5)
|
176 |
+
self.tform = ko.augmentation.RandomGamma(gamma, gain, p=p)
|
177 |
+
|
178 |
+
def forward(self, x, **kwargs):
|
179 |
+
return self.tform(x)
|
180 |
+
|
181 |
+
class RandomGaussianBlur(nn.Module):
|
182 |
+
def __init__(self, severity='medium', p=0.5):
|
183 |
+
super().__init__()
|
184 |
+
self.p = p
|
185 |
+
if severity == 'low':
|
186 |
+
kernel_size, sigma = 3, (0.1, 1.0)
|
187 |
+
elif severity == 'medium':
|
188 |
+
kernel_size, sigma = 5, (0.1, 1.5)
|
189 |
+
elif severity == 'high':
|
190 |
+
kernel_size, sigma = 7, (0.1, 2.0)
|
191 |
+
self.tform = ko.augmentation.RandomGaussianBlur(kernel_size=(kernel_size, kernel_size), sigma=sigma, p=self.p)
|
192 |
+
|
193 |
+
def forward(self, x, **kwargs):
|
194 |
+
return self.tform(x)
|
195 |
+
|
196 |
+
class RandomGaussianNoise(nn.Module):
|
197 |
+
def __init__(self, severity='medium', p=0.5):
|
198 |
+
super().__init__()
|
199 |
+
self.p = p
|
200 |
+
if severity == 'low':
|
201 |
+
std = 0.02
|
202 |
+
elif severity == 'medium':
|
203 |
+
std = 0.04
|
204 |
+
elif severity == 'high':
|
205 |
+
std = 0.08
|
206 |
+
self.tform = ko.augmentation.RandomGaussianNoise(mean=0., std=std, p=p)
|
207 |
+
|
208 |
+
def forward(self, x, **kwargs):
|
209 |
+
return self.tform(x)
|
210 |
+
|
211 |
+
class RandomMotionBlur(nn.Module):
|
212 |
+
def __init__(self, severity='medium', p=0.5):
|
213 |
+
super().__init__()
|
214 |
+
self.p = p
|
215 |
+
if severity == 'low':
|
216 |
+
kernel_size, angle, direction = (3, 5), (-25, 25), (-0.25, 0.25)
|
217 |
+
elif severity == 'medium':
|
218 |
+
kernel_size, angle, direction = (3, 7), (-45, 45), (-0.5, 0.5)
|
219 |
+
elif severity == 'high':
|
220 |
+
kernel_size, angle, direction = (3, 9), (-90, 90), (-1.0, 1.0)
|
221 |
+
self.tform = ko.augmentation.RandomMotionBlur(kernel_size, angle, direction, p=p)
|
222 |
+
|
223 |
+
def forward(self, x, **kwargs):
|
224 |
+
return self.tform(x)
|
225 |
+
|
226 |
+
class RandomPosterize(nn.Module):
|
227 |
+
def __init__(self, severity='medium', p=0.5):
|
228 |
+
super().__init__()
|
229 |
+
self.p = p
|
230 |
+
if severity == 'low':
|
231 |
+
bits = 5
|
232 |
+
elif severity == 'medium':
|
233 |
+
bits = 4
|
234 |
+
elif severity == 'high':
|
235 |
+
bits = 3
|
236 |
+
self.tform = ko.augmentation.RandomPosterize(bits=bits, p=p)
|
237 |
+
|
238 |
+
def forward(self, x, **kwargs):
|
239 |
+
return self.tform(x)
|
240 |
+
|
241 |
+
class RandomRGBShift(nn.Module):
|
242 |
+
def __init__(self, severity='medium', p=0.5):
|
243 |
+
super().__init__()
|
244 |
+
self.p = p
|
245 |
+
if severity == 'low':
|
246 |
+
rgb = 0.02
|
247 |
+
elif severity == 'medium':
|
248 |
+
rgb = 0.05
|
249 |
+
elif severity == 'high':
|
250 |
+
rgb = 0.1
|
251 |
+
self.tform = ko.augmentation.RandomRGBShift(r_shift_limit=rgb, g_shift_limit=rgb, b_shift_limit=rgb, p=p)
|
252 |
+
|
253 |
+
def forward(self, x, **kwargs):
|
254 |
+
return self.tform(x)
|
255 |
+
|
256 |
+
|
257 |
+
|
258 |
+
class TransformNet(nn.Module):
|
259 |
+
def __init__(self, flip=True, crop_mode='random_crop', compress=True, brightness=True, contrast=True, color_jiggle=True, gamma=False, grayscale=True, gaussian_blur=True, gaussian_noise=True, hue=True, motion_blur=True, posterize=True, rgb_shift=True, saturation=True, sharpness=True, median_blur=True, box_blur=True, severity='medium', n_optional=2, ramp=1000, p=0.5):
|
260 |
+
super().__init__()
|
261 |
+
self.n_optional = n_optional
|
262 |
+
self.p = p
|
263 |
+
p_flip = 0.5 if flip else 0
|
264 |
+
rnd_flip_layer = ko.augmentation.RandomHorizontalFlip(p_flip)
|
265 |
+
self.ramp = ramp
|
266 |
+
self.register_buffer('step0', torch.tensor(0))
|
267 |
+
|
268 |
+
self.crop_mode = crop_mode
|
269 |
+
assert crop_mode in ['random_crop', 'resized_crop']
|
270 |
+
if crop_mode == 'random_crop':
|
271 |
+
rnd_crop_layer = ko.augmentation.RandomCrop((224,224), cropping_mode="resample")
|
272 |
+
elif crop_mode == 'resized_crop':
|
273 |
+
rnd_crop_layer = ko.augmentation.RandomResizedCrop(size=(224,224), scale=(0.7, 1.0), ratio=(3.0/4, 4.0/3), cropping_mode='resample')
|
274 |
+
|
275 |
+
self.fixed_transforms = [rnd_flip_layer, rnd_crop_layer]
|
276 |
+
if compress:
|
277 |
+
self.register(RandomCompress(severity, p=p), 'Random Compress')
|
278 |
+
if brightness:
|
279 |
+
self.register(RandomBrightness(severity, p=p), 'Random Brightness')
|
280 |
+
if contrast:
|
281 |
+
self.register(RandomContrast(severity, p=p), 'Random Contrast')
|
282 |
+
if color_jiggle:
|
283 |
+
self.register(RandomColorJiggle(severity, p=p), 'Random Color')
|
284 |
+
if gamma:
|
285 |
+
self.register(RandomGamma(severity, p=p), 'Random Gamma')
|
286 |
+
if grayscale:
|
287 |
+
self.register(ko.augmentation.RandomGrayscale(p=p), 'Grayscale')
|
288 |
+
if gaussian_blur:
|
289 |
+
self.register(RandomGaussianBlur(severity, p=p), 'Random Gaussian Blur')
|
290 |
+
if gaussian_noise:
|
291 |
+
self.register(RandomGaussianNoise(severity, p=p), 'Random Gaussian Noise')
|
292 |
+
if hue:
|
293 |
+
self.register(RandomHue(severity, p=p), 'Random Hue')
|
294 |
+
if motion_blur:
|
295 |
+
self.register(RandomMotionBlur(severity, p=p), 'Random Motion Blur')
|
296 |
+
if posterize:
|
297 |
+
self.register(RandomPosterize(severity, p=p), 'Random Posterize')
|
298 |
+
if rgb_shift:
|
299 |
+
self.register(RandomRGBShift(severity, p=p), 'Random RGB Shift')
|
300 |
+
if saturation:
|
301 |
+
self.register(RandomSaturation(severity, p=p), 'Random Saturation')
|
302 |
+
if sharpness:
|
303 |
+
self.register(RandomSharpness(severity, p=p), 'Random Sharpness')
|
304 |
+
if median_blur:
|
305 |
+
self.register(RandomMedianBlur(severity, p=p), 'Random Median Blur')
|
306 |
+
if box_blur:
|
307 |
+
self.register(RandomBoxBlur(severity, p=p), 'Random Box Blur')
|
308 |
+
|
309 |
+
def register(self, tform, name):
|
310 |
+
# register a new (optional) transform
|
311 |
+
if not hasattr(self, 'optional_transforms'):
|
312 |
+
self.optional_transforms = []
|
313 |
+
self.optional_names = []
|
314 |
+
self.optional_transforms.append(tform)
|
315 |
+
self.optional_names.append(name)
|
316 |
+
|
317 |
+
def activate(self, global_step):
|
318 |
+
if self.step0 == 0:
|
319 |
+
print(f'[TRAINING] Activating TransformNet at step {global_step}')
|
320 |
+
self.step0 = torch.tensor(global_step)
|
321 |
+
|
322 |
+
def is_activated(self):
|
323 |
+
return self.step0 > 0
|
324 |
+
|
325 |
+
def forward(self, x, global_step, p=0.9):
|
326 |
+
# x: [batch_size, 3, H, W] in range [-1, 1]
|
327 |
+
x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
|
328 |
+
# fixed transforms
|
329 |
+
for tform in self.fixed_transforms:
|
330 |
+
x = tform(x)
|
331 |
+
if isinstance(x, tuple):
|
332 |
+
x = x[0]
|
333 |
+
|
334 |
+
# optional transforms
|
335 |
+
ramp = np.min([(global_step-self.step0.cpu().item()) / self.ramp, 1.])
|
336 |
+
if len(self.optional_transforms) > 0:
|
337 |
+
tform_ids = torch.randint(len(self.optional_transforms), (self.n_optional,)).numpy()
|
338 |
+
for tform_id in tform_ids:
|
339 |
+
tform = self.optional_transforms[tform_id]
|
340 |
+
x = tform(x, ramp=ramp)
|
341 |
+
if isinstance(x, tuple):
|
342 |
+
x = x[0]
|
343 |
+
|
344 |
+
return x * 2 - 1 # [0, 1] -> [-1, 1]
|
345 |
+
|
346 |
+
def transform_by_id(self, x, tform_id):
|
347 |
+
# x: [batch_size, 3, H, W] in range [-1, 1]
|
348 |
+
x = x * 0.5 + 0.5 # [-1, 1] -> [0, 1]
|
349 |
+
# fixed transforms
|
350 |
+
for tform in self.fixed_transforms:
|
351 |
+
x = tform(x)
|
352 |
+
if isinstance(x, tuple):
|
353 |
+
x = x[0]
|
354 |
+
|
355 |
+
# optional transforms
|
356 |
+
tform = self.optional_transforms[tform_id]
|
357 |
+
x = tform(x)
|
358 |
+
if isinstance(x, tuple):
|
359 |
+
x = x[0]
|
360 |
+
return x * 2 - 1 # [0, 1] -> [-1, 1]
|
361 |
+
|
362 |
+
def transform_by_name(self, x, tform_name):
|
363 |
+
assert tform_name in self.optional_names
|
364 |
+
tform_id = self.optional_names.index(tform_name)
|
365 |
+
return self.transform_by_id(x, tform_id)
|
366 |
+
|
367 |
+
def apply_transform_on_pil_image(self, x, tform_name):
|
368 |
+
# x: PIL image
|
369 |
+
# return: PIL image
|
370 |
+
assert tform_name in self.optional_names + ['Fixed Augment']
|
371 |
+
# if tform_name == 'Random Crop': # the only transform dependent on image size
|
372 |
+
# # crop equivalent to 224/256
|
373 |
+
# w, h = x.size
|
374 |
+
# new_w, new_h = int(224 / 256 * w), int(224 / 256 * h)
|
375 |
+
# x = transforms.RandomCrop((new_h, new_w))(x)
|
376 |
+
# return x
|
377 |
+
|
378 |
+
# x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
|
379 |
+
# x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
|
380 |
+
# if tform_name == 'Random Flip':
|
381 |
+
# x = self.fixed_transforms[0](x)
|
382 |
+
# else:
|
383 |
+
# tform_id = self.optional_names.index(tform_name)
|
384 |
+
# tform = self.optional_transforms[tform_id]
|
385 |
+
# x = tform(x)
|
386 |
+
# if isinstance(x, tuple):
|
387 |
+
# x = x[0]
|
388 |
+
# x = x.detach().squeeze(0).permute(1, 2, 0).numpy() * 255 # [0, 1] -> [0, 255]
|
389 |
+
# return Image.fromarray(x.astype(np.uint8))
|
390 |
+
|
391 |
+
w, h = x.size
|
392 |
+
x = x.resize((256, 256), Image.BILINEAR)
|
393 |
+
x = np.array(x).astype(np.float32) / 255. # [0, 255] -> [0, 1]
|
394 |
+
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
|
395 |
+
if tform_name == 'Fixed Augment':
|
396 |
+
for tform in self.fixed_transforms:
|
397 |
+
x = tform(x)
|
398 |
+
if isinstance(x, tuple):
|
399 |
+
x = x[0]
|
400 |
+
else:
|
401 |
+
tform_id = self.optional_names.index(tform_name)
|
402 |
+
tform = self.optional_transforms[tform_id]
|
403 |
+
x = tform(x)
|
404 |
+
if isinstance(x, tuple):
|
405 |
+
x = x[0]
|
406 |
+
x = x.detach().squeeze(0).permute(1, 2, 0).numpy() * 255 # [0, 1] -> [0, 255]
|
407 |
+
x = Image.fromarray(x.astype(np.uint8))
|
408 |
+
if (tform_name == 'Random Crop') and (self.crop_mode == 'random_crop'):
|
409 |
+
w, h = int(224 / 256 * w), int(224 / 256 * h)
|
410 |
+
x = x.resize((w, h), Image.BILINEAR)
|
411 |
+
return x
|
412 |
+
|
413 |
+
|
414 |
+
if __name__ == '__main__':
|
415 |
+
pass
|
cldm/utils.py
ADDED
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import itertools
|
3 |
+
import numpy as np
|
4 |
+
import random
|
5 |
+
import torch
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torch.nn as nn
|
8 |
+
|
9 |
+
from PIL import Image, ImageOps
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
|
12 |
+
def random_blur_kernel(probs, N_blur, sigrange_gauss, sigrange_line, wmin_line):
|
13 |
+
N = N_blur
|
14 |
+
coords = torch.from_numpy(np.stack(np.meshgrid(range(N_blur), range(N_blur), indexing='ij'), axis=-1)) - (0.5 * (N-1)) # (7,7,2)
|
15 |
+
manhat = torch.sum(torch.abs(coords), dim=-1) # (7, 7)
|
16 |
+
|
17 |
+
# nothing, default
|
18 |
+
vals_nothing = (manhat < 0.5).float() # (7, 7)
|
19 |
+
|
20 |
+
# gauss
|
21 |
+
sig_gauss = torch.rand(1)[0] * (sigrange_gauss[1] - sigrange_gauss[0]) + sigrange_gauss[0]
|
22 |
+
vals_gauss = torch.exp(-torch.sum(coords ** 2, dim=-1) /2. / sig_gauss ** 2)
|
23 |
+
|
24 |
+
# line
|
25 |
+
theta = torch.rand(1)[0] * 2.* np.pi
|
26 |
+
v = torch.FloatTensor([torch.cos(theta), torch.sin(theta)]) # (2)
|
27 |
+
dists = torch.sum(coords * v, dim=-1) # (7, 7)
|
28 |
+
|
29 |
+
sig_line = torch.rand(1)[0] * (sigrange_line[1] - sigrange_line[0]) + sigrange_line[0]
|
30 |
+
w_line = torch.rand(1)[0] * (0.5 * (N-1) + 0.1 - wmin_line) + wmin_line
|
31 |
+
|
32 |
+
vals_line = torch.exp(-dists ** 2 / 2. / sig_line ** 2) * (manhat < w_line) # (7, 7)
|
33 |
+
|
34 |
+
t = torch.rand(1)[0]
|
35 |
+
vals = vals_nothing
|
36 |
+
if t < (probs[0] + probs[1]):
|
37 |
+
vals = vals_line
|
38 |
+
else:
|
39 |
+
vals = vals
|
40 |
+
if t < probs[0]:
|
41 |
+
vals = vals_gauss
|
42 |
+
else:
|
43 |
+
vals = vals
|
44 |
+
|
45 |
+
v = vals / torch.sum(vals) # 归一化 (7, 7)
|
46 |
+
z = torch.zeros_like(v)
|
47 |
+
f = torch.stack([v,z,z, z,v,z, z,z,v], dim=0).reshape([3, 3, N, N])
|
48 |
+
return f
|
49 |
+
|
50 |
+
|
51 |
+
def get_rand_transform_matrix(image_size, d, batch_size):
|
52 |
+
Ms = np.zeros((batch_size, 2, 3, 3))
|
53 |
+
for i in range(batch_size):
|
54 |
+
tl_x = random.uniform(-d, d) # Top left corner, top
|
55 |
+
tl_y = random.uniform(-d, d) # Top left corner, left
|
56 |
+
bl_x = random.uniform(-d, d) # Bot left corner, bot
|
57 |
+
bl_y = random.uniform(-d, d) # Bot left corner, left
|
58 |
+
tr_x = random.uniform(-d, d) # Top right corner, top
|
59 |
+
tr_y = random.uniform(-d, d) # Top right corner, right
|
60 |
+
br_x = random.uniform(-d, d) # Bot right corner, bot
|
61 |
+
br_y = random.uniform(-d, d) # Bot right corner, right
|
62 |
+
|
63 |
+
rect = np.array([
|
64 |
+
[tl_x, tl_y],
|
65 |
+
[tr_x + image_size, tr_y],
|
66 |
+
[br_x + image_size, br_y + image_size],
|
67 |
+
[bl_x, bl_y + image_size]], dtype = "float32")
|
68 |
+
|
69 |
+
dst = np.array([
|
70 |
+
[0, 0],
|
71 |
+
[image_size, 0],
|
72 |
+
[image_size, image_size],
|
73 |
+
[0, image_size]], dtype = "float32")
|
74 |
+
|
75 |
+
M = cv2.getPerspectiveTransform(rect, dst)
|
76 |
+
M_inv = np.linalg.inv(M)
|
77 |
+
Ms[i, 0, :, :] = M_inv
|
78 |
+
Ms[i, 1, :, :] = M
|
79 |
+
Ms = torch.from_numpy(Ms).float()
|
80 |
+
|
81 |
+
return Ms
|
82 |
+
|
83 |
+
|
84 |
+
def get_rnd_brightness_torch(rnd_bri, rnd_hue, batch_size):
|
85 |
+
rnd_hue = torch.FloatTensor(batch_size, 3, 1, 1).uniform_(-rnd_hue, rnd_hue)
|
86 |
+
rnd_brightness = torch.FloatTensor(batch_size, 1, 1, 1).uniform_(-rnd_bri, rnd_bri)
|
87 |
+
return rnd_hue + rnd_brightness
|
88 |
+
|
89 |
+
|
90 |
+
# reference: https://github.com/mlomnitz/DiffJPEG.git
|
91 |
+
y_table = np.array(
|
92 |
+
[[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60,
|
93 |
+
55], [14, 13, 16, 24, 40, 57, 69, 56],
|
94 |
+
[14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103,
|
95 |
+
77], [24, 35, 55, 64, 81, 104, 113, 92],
|
96 |
+
[49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
|
97 |
+
dtype=np.float32).T
|
98 |
+
|
99 |
+
y_table = nn.Parameter(torch.from_numpy(y_table))
|
100 |
+
c_table = np.empty((8, 8), dtype=np.float32)
|
101 |
+
c_table.fill(99)
|
102 |
+
c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66],
|
103 |
+
[24, 26, 56, 99], [47, 66, 99, 99]]).T
|
104 |
+
c_table = nn.Parameter(torch.from_numpy(c_table))
|
105 |
+
|
106 |
+
# 1. RGB -> YCbCr
|
107 |
+
class rgb_to_ycbcr_jpeg(nn.Module):
|
108 |
+
""" Converts RGB image to YCbCr
|
109 |
+
Input:
|
110 |
+
image(tensor): batch x 3 x height x width
|
111 |
+
Outpput:
|
112 |
+
result(tensor): batch x height x width x 3
|
113 |
+
"""
|
114 |
+
def __init__(self):
|
115 |
+
super(rgb_to_ycbcr_jpeg, self).__init__()
|
116 |
+
matrix = np.array(
|
117 |
+
[[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5],
|
118 |
+
[0.5, -0.418688, -0.081312]], dtype=np.float32).T
|
119 |
+
self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
|
120 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
121 |
+
|
122 |
+
def forward(self, image):
|
123 |
+
image = image.permute(0, 2, 3, 1)
|
124 |
+
result = torch.tensordot(image, self.matrix, dims=1) + self.shift
|
125 |
+
result.view(image.shape)
|
126 |
+
return result
|
127 |
+
|
128 |
+
# 2. Chroma subsampling
|
129 |
+
class chroma_subsampling(nn.Module):
|
130 |
+
""" Chroma subsampling on CbCv channels
|
131 |
+
Input:
|
132 |
+
image(tensor): batch x height x width x 3
|
133 |
+
Output:
|
134 |
+
y(tensor): batch x height x width
|
135 |
+
cb(tensor): batch x height/2 x width/2
|
136 |
+
cr(tensor): batch x height/2 x width/2
|
137 |
+
"""
|
138 |
+
def __init__(self):
|
139 |
+
super(chroma_subsampling, self).__init__()
|
140 |
+
|
141 |
+
def forward(self, image):
|
142 |
+
image_2 = image.permute(0, 3, 1, 2).clone()
|
143 |
+
avg_pool = nn.AvgPool2d(kernel_size=2, stride=(2, 2),
|
144 |
+
count_include_pad=False)
|
145 |
+
cb = avg_pool(image_2[:, 1, :, :].unsqueeze(1))
|
146 |
+
cr = avg_pool(image_2[:, 2, :, :].unsqueeze(1))
|
147 |
+
cb = cb.permute(0, 2, 3, 1)
|
148 |
+
cr = cr.permute(0, 2, 3, 1)
|
149 |
+
return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3)
|
150 |
+
|
151 |
+
# 3. Block splitting
|
152 |
+
class block_splitting(nn.Module):
|
153 |
+
""" Splitting image into patches
|
154 |
+
Input:
|
155 |
+
image(tensor): batch x height x width
|
156 |
+
Output:
|
157 |
+
patch(tensor): batch x h*w/64 x h x w
|
158 |
+
"""
|
159 |
+
def __init__(self):
|
160 |
+
super(block_splitting, self).__init__()
|
161 |
+
self.k = 8
|
162 |
+
|
163 |
+
def forward(self, image):
|
164 |
+
height, width = image.shape[1:3]
|
165 |
+
batch_size = image.shape[0]
|
166 |
+
image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k)
|
167 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
168 |
+
return image_transposed.contiguous().view(batch_size, -1, self.k, self.k)
|
169 |
+
|
170 |
+
# 4. DCT
|
171 |
+
class dct_8x8(nn.Module):
|
172 |
+
""" Discrete Cosine Transformation
|
173 |
+
Input:
|
174 |
+
image(tensor): batch x height x width
|
175 |
+
Output:
|
176 |
+
dcp(tensor): batch x height x width
|
177 |
+
"""
|
178 |
+
def __init__(self):
|
179 |
+
super(dct_8x8, self).__init__()
|
180 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
181 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
182 |
+
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos(
|
183 |
+
(2 * y + 1) * v * np.pi / 16)
|
184 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
185 |
+
#
|
186 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
187 |
+
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float() )
|
188 |
+
|
189 |
+
def forward(self, image):
|
190 |
+
image = image - 128
|
191 |
+
result = self.scale * torch.tensordot(image, self.tensor, dims=2)
|
192 |
+
result.view(image.shape)
|
193 |
+
return result
|
194 |
+
|
195 |
+
# 5. Quantization
|
196 |
+
class y_quantize(nn.Module):
|
197 |
+
""" JPEG Quantization for Y channel
|
198 |
+
Input:
|
199 |
+
image(tensor): batch x height x width
|
200 |
+
rounding(function): rounding function to use
|
201 |
+
factor(float): Degree of compression
|
202 |
+
Output:
|
203 |
+
image(tensor): batch x height x width
|
204 |
+
"""
|
205 |
+
def __init__(self, rounding, factor=1):
|
206 |
+
super(y_quantize, self).__init__()
|
207 |
+
self.rounding = rounding
|
208 |
+
self.factor = factor
|
209 |
+
self.y_table = y_table
|
210 |
+
|
211 |
+
def forward(self, image):
|
212 |
+
image = image.float() / (self.y_table * self.factor)
|
213 |
+
image = self.rounding(image)
|
214 |
+
return image
|
215 |
+
|
216 |
+
|
217 |
+
class c_quantize(nn.Module):
|
218 |
+
""" JPEG Quantization for CrCb channels
|
219 |
+
Input:
|
220 |
+
image(tensor): batch x height x width
|
221 |
+
rounding(function): rounding function to use
|
222 |
+
factor(float): Degree of compression
|
223 |
+
Output:
|
224 |
+
image(tensor): batch x height x width
|
225 |
+
"""
|
226 |
+
def __init__(self, rounding, factor=1):
|
227 |
+
super(c_quantize, self).__init__()
|
228 |
+
self.rounding = rounding
|
229 |
+
self.factor = factor
|
230 |
+
self.c_table = c_table
|
231 |
+
|
232 |
+
def forward(self, image):
|
233 |
+
image = image.float() / (self.c_table * self.factor)
|
234 |
+
image = self.rounding(image)
|
235 |
+
return image
|
236 |
+
|
237 |
+
|
238 |
+
class compress_jpeg(nn.Module):
|
239 |
+
""" Full JPEG compression algortihm
|
240 |
+
Input:
|
241 |
+
imgs(tensor): batch x 3 x height x width
|
242 |
+
rounding(function): rounding function to use
|
243 |
+
factor(float): Compression factor
|
244 |
+
Ouput:
|
245 |
+
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
|
246 |
+
"""
|
247 |
+
def __init__(self, rounding=torch.round, factor=1):
|
248 |
+
super(compress_jpeg, self).__init__()
|
249 |
+
self.l1 = nn.Sequential(
|
250 |
+
rgb_to_ycbcr_jpeg(),
|
251 |
+
chroma_subsampling()
|
252 |
+
)
|
253 |
+
self.l2 = nn.Sequential(
|
254 |
+
block_splitting(),
|
255 |
+
dct_8x8()
|
256 |
+
)
|
257 |
+
self.c_quantize = c_quantize(rounding=rounding, factor=factor)
|
258 |
+
self.y_quantize = y_quantize(rounding=rounding, factor=factor)
|
259 |
+
|
260 |
+
def forward(self, image):
|
261 |
+
y, cb, cr = self.l1(image*255)
|
262 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
263 |
+
for k in components.keys():
|
264 |
+
comp = self.l2(components[k])
|
265 |
+
if k in ('cb', 'cr'):
|
266 |
+
comp = self.c_quantize(comp)
|
267 |
+
else:
|
268 |
+
comp = self.y_quantize(comp)
|
269 |
+
|
270 |
+
components[k] = comp
|
271 |
+
|
272 |
+
return components['y'], components['cb'], components['cr']
|
273 |
+
|
274 |
+
# -5. Dequantization
|
275 |
+
class y_dequantize(nn.Module):
|
276 |
+
""" Dequantize Y channel
|
277 |
+
Inputs:
|
278 |
+
image(tensor): batch x height x width
|
279 |
+
factor(float): compression factor
|
280 |
+
Outputs:
|
281 |
+
image(tensor): batch x height x width
|
282 |
+
"""
|
283 |
+
def __init__(self, factor=1):
|
284 |
+
super(y_dequantize, self).__init__()
|
285 |
+
self.y_table = y_table
|
286 |
+
self.factor = factor
|
287 |
+
|
288 |
+
def forward(self, image):
|
289 |
+
return image * (self.y_table * self.factor)
|
290 |
+
|
291 |
+
|
292 |
+
class c_dequantize(nn.Module):
|
293 |
+
""" Dequantize CbCr channel
|
294 |
+
Inputs:
|
295 |
+
image(tensor): batch x height x width
|
296 |
+
factor(float): compression factor
|
297 |
+
Outputs:
|
298 |
+
image(tensor): batch x height x width
|
299 |
+
"""
|
300 |
+
def __init__(self, factor=1):
|
301 |
+
super(c_dequantize, self).__init__()
|
302 |
+
self.factor = factor
|
303 |
+
self.c_table = c_table
|
304 |
+
|
305 |
+
def forward(self, image):
|
306 |
+
return image * (self.c_table * self.factor)
|
307 |
+
|
308 |
+
# -4. Inverse DCT
|
309 |
+
class idct_8x8(nn.Module):
|
310 |
+
""" Inverse discrete Cosine Transformation
|
311 |
+
Input:
|
312 |
+
dcp(tensor): batch x height x width
|
313 |
+
Output:
|
314 |
+
image(tensor): batch x height x width
|
315 |
+
"""
|
316 |
+
def __init__(self):
|
317 |
+
super(idct_8x8, self).__init__()
|
318 |
+
alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
|
319 |
+
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
|
320 |
+
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
|
321 |
+
for x, y, u, v in itertools.product(range(8), repeat=4):
|
322 |
+
tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
|
323 |
+
(2 * v + 1) * y * np.pi / 16)
|
324 |
+
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
|
325 |
+
|
326 |
+
def forward(self, image):
|
327 |
+
image = image * self.alpha
|
328 |
+
result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
|
329 |
+
result.view(image.shape)
|
330 |
+
return result
|
331 |
+
|
332 |
+
# -3. Block joining
|
333 |
+
class block_merging(nn.Module):
|
334 |
+
""" Merge pathces into image
|
335 |
+
Inputs:
|
336 |
+
patches(tensor) batch x height*width/64, height x width
|
337 |
+
height(int)
|
338 |
+
width(int)
|
339 |
+
Output:
|
340 |
+
image(tensor): batch x height x width
|
341 |
+
"""
|
342 |
+
def __init__(self):
|
343 |
+
super(block_merging, self).__init__()
|
344 |
+
|
345 |
+
def forward(self, patches, height, width):
|
346 |
+
k = 8
|
347 |
+
batch_size = patches.shape[0]
|
348 |
+
image_reshaped = patches.view(batch_size, height//k, width//k, k, k)
|
349 |
+
image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
|
350 |
+
return image_transposed.contiguous().view(batch_size, height, width)
|
351 |
+
|
352 |
+
# -2. Chroma upsampling
|
353 |
+
class chroma_upsampling(nn.Module):
|
354 |
+
""" Upsample chroma layers
|
355 |
+
Input:
|
356 |
+
y(tensor): y channel image
|
357 |
+
cb(tensor): cb channel
|
358 |
+
cr(tensor): cr channel
|
359 |
+
Ouput:
|
360 |
+
image(tensor): batch x height x width x 3
|
361 |
+
"""
|
362 |
+
def __init__(self):
|
363 |
+
super(chroma_upsampling, self).__init__()
|
364 |
+
|
365 |
+
def forward(self, y, cb, cr):
|
366 |
+
def repeat(x, k=2):
|
367 |
+
height, width = x.shape[1:3]
|
368 |
+
x = x.unsqueeze(-1)
|
369 |
+
x = x.repeat(1, 1, k, k)
|
370 |
+
x = x.view(-1, height * k, width * k)
|
371 |
+
return x
|
372 |
+
|
373 |
+
cb = repeat(cb)
|
374 |
+
cr = repeat(cr)
|
375 |
+
|
376 |
+
return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
|
377 |
+
|
378 |
+
# -1: YCbCr -> RGB
|
379 |
+
class ycbcr_to_rgb_jpeg(nn.Module):
|
380 |
+
""" Converts YCbCr image to RGB JPEG
|
381 |
+
Input:
|
382 |
+
image(tensor): batch x height x width x 3
|
383 |
+
Outpput:
|
384 |
+
result(tensor): batch x 3 x height x width
|
385 |
+
"""
|
386 |
+
def __init__(self):
|
387 |
+
super(ycbcr_to_rgb_jpeg, self).__init__()
|
388 |
+
|
389 |
+
matrix = np.array(
|
390 |
+
[[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
|
391 |
+
dtype=np.float32).T
|
392 |
+
self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
|
393 |
+
self.matrix = nn.Parameter(torch.from_numpy(matrix))
|
394 |
+
|
395 |
+
def forward(self, image):
|
396 |
+
result = torch.tensordot(image + self.shift, self.matrix, dims=1)
|
397 |
+
result.view(image.shape)
|
398 |
+
return result.permute(0, 3, 1, 2)
|
399 |
+
|
400 |
+
|
401 |
+
class decompress_jpeg(nn.Module):
|
402 |
+
""" Full JPEG decompression algortihm
|
403 |
+
Input:
|
404 |
+
compressed(dict(tensor)): batch x h*w/64 x 8 x 8
|
405 |
+
rounding(function): rounding function to use
|
406 |
+
factor(float): Compression factor
|
407 |
+
Ouput:
|
408 |
+
image(tensor): batch x 3 x height x width
|
409 |
+
"""
|
410 |
+
def __init__(self, height, width, rounding=torch.round, factor=1):
|
411 |
+
super(decompress_jpeg, self).__init__()
|
412 |
+
self.c_dequantize = c_dequantize(factor=factor)
|
413 |
+
self.y_dequantize = y_dequantize(factor=factor)
|
414 |
+
self.idct = idct_8x8()
|
415 |
+
self.merging = block_merging()
|
416 |
+
self.chroma = chroma_upsampling()
|
417 |
+
self.colors = ycbcr_to_rgb_jpeg()
|
418 |
+
|
419 |
+
self.height, self.width = height, width
|
420 |
+
|
421 |
+
def forward(self, y, cb, cr):
|
422 |
+
components = {'y': y, 'cb': cb, 'cr': cr}
|
423 |
+
for k in components.keys():
|
424 |
+
if k in ('cb', 'cr'):
|
425 |
+
comp = self.c_dequantize(components[k])
|
426 |
+
height, width = int(self.height/2), int(self.width/2)
|
427 |
+
else:
|
428 |
+
comp = self.y_dequantize(components[k])
|
429 |
+
height, width = self.height, self.width
|
430 |
+
comp = self.idct(comp)
|
431 |
+
components[k] = self.merging(comp, height, width)
|
432 |
+
#
|
433 |
+
image = self.chroma(components['y'], components['cb'], components['cr'])
|
434 |
+
image = self.colors(image)
|
435 |
+
|
436 |
+
image = torch.min(255*torch.ones_like(image),
|
437 |
+
torch.max(torch.zeros_like(image), image))
|
438 |
+
return image/255
|
439 |
+
|
440 |
+
def diff_round(x):
|
441 |
+
""" Differentiable rounding function
|
442 |
+
Input:
|
443 |
+
x(tensor)
|
444 |
+
Output:
|
445 |
+
x(tensor)
|
446 |
+
"""
|
447 |
+
return torch.round(x) + (x - torch.round(x))**3
|
448 |
+
|
449 |
+
def round_only_at_0(x):
|
450 |
+
cond = (torch.abs(x) < 0.5).float()
|
451 |
+
return cond * (x ** 3) + (1 - cond) * x
|
452 |
+
|
453 |
+
def quality_to_factor(quality):
|
454 |
+
""" Calculate factor corresponding to quality
|
455 |
+
Input:
|
456 |
+
quality(float): Quality for jpeg compression
|
457 |
+
Output:
|
458 |
+
factor(float): Compression factor
|
459 |
+
"""
|
460 |
+
if quality < 50:
|
461 |
+
quality = 5000. / quality
|
462 |
+
else:
|
463 |
+
quality = 200. - quality*2
|
464 |
+
return quality / 100.
|
465 |
+
|
466 |
+
def jpeg_compress_decompress(image,
|
467 |
+
# downsample_c=True,
|
468 |
+
rounding=round_only_at_0,
|
469 |
+
quality=80):
|
470 |
+
# image_r = image * 255
|
471 |
+
height, width = image.shape[2:4]
|
472 |
+
# orig_height, orig_width = height, width
|
473 |
+
# if height % 16 != 0 or width % 16 != 0:
|
474 |
+
# # Round up to next multiple of 16
|
475 |
+
# height = ((height - 1) // 16 + 1) * 16
|
476 |
+
# width = ((width - 1) // 16 + 1) * 16
|
477 |
+
|
478 |
+
# vpad = height - orig_height
|
479 |
+
# wpad = width - orig_width
|
480 |
+
# top = vpad // 2
|
481 |
+
# bottom = vpad - top
|
482 |
+
# left = wpad // 2
|
483 |
+
# right = wpad - left
|
484 |
+
# #image = tf.pad(image, [[0, 0], [top, bottom], [left, right], [0, 0]], 'SYMMETRIC')
|
485 |
+
# image = torch.pad(image, [[0, 0], [0, vpad], [0, wpad], [0, 0]], 'reflect')
|
486 |
+
|
487 |
+
factor = quality_to_factor(quality)
|
488 |
+
|
489 |
+
compress = compress_jpeg(rounding=rounding, factor=factor).to(image.device)
|
490 |
+
decompress = decompress_jpeg(height, width, rounding=rounding, factor=factor).to(image.device)
|
491 |
+
|
492 |
+
y, cb, cr = compress(image)
|
493 |
+
recovered = decompress(y, cb, cr)
|
494 |
+
|
495 |
+
return recovered.contiguous()
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == '__main__':
|
499 |
+
''' test JPEG compress and decompress'''
|
500 |
+
# img = Image.open('house.jpg')
|
501 |
+
# img = np.array(img) / 255.
|
502 |
+
# img_r = np.transpose(img, [2, 0, 1])
|
503 |
+
# img_tensor = torch.from_numpy(img_r).unsqueeze(0).float()
|
504 |
+
|
505 |
+
# recover = jpeg_compress_decompress(img_tensor)
|
506 |
+
|
507 |
+
# recover_arr = recover.detach().squeeze(0).numpy()
|
508 |
+
# recover_arr = np.transpose(recover_arr, [1, 2, 0])
|
509 |
+
|
510 |
+
# plt.subplot(121)
|
511 |
+
# plt.imshow(img)
|
512 |
+
# plt.subplot(122)
|
513 |
+
# plt.imshow(recover_arr)
|
514 |
+
# plt.show()
|
515 |
+
|
516 |
+
''' test blur '''
|
517 |
+
# blur
|
518 |
+
|
519 |
+
img = Image.open('house.jpg')
|
520 |
+
img = np.array(img) / 255.
|
521 |
+
img_r = np.transpose(img, [2, 0, 1])
|
522 |
+
img_tensor = torch.from_numpy(img_r).unsqueeze(0).float()
|
523 |
+
print(img_tensor.shape)
|
524 |
+
|
525 |
+
N_blur=7
|
526 |
+
f = random_blur_kernel(probs=[.25, .25], N_blur=N_blur, sigrange_gauss=[1., 3.], sigrange_line=[.25, 1.], wmin_line=3)
|
527 |
+
# print(f.shape)
|
528 |
+
# print(type(f))
|
529 |
+
encoded_image = F.conv2d(img_tensor, f, bias=None, padding=int((N_blur-1)/2))
|
530 |
+
|
531 |
+
encoded_image = encoded_image.detach().squeeze(0).numpy()
|
532 |
+
encoded_image = np.transpose(encoded_image, [1, 2, 0])
|
533 |
+
|
534 |
+
plt.subplot(121)
|
535 |
+
plt.imshow(img)
|
536 |
+
plt.subplot(122)
|
537 |
+
plt.imshow(encoded_image)
|
538 |
+
plt.show()
|
539 |
+
|
flae/models.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
from torch.nn import functional as thf
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from ldm.util import instantiate_from_config
|
7 |
+
import einops
|
8 |
+
import kornia
|
9 |
+
import numpy as np
|
10 |
+
import torchvision
|
11 |
+
from contextlib import contextmanager
|
12 |
+
from ldm.modules.ema import LitEma
|
13 |
+
|
14 |
+
|
15 |
+
class FlAE(pl.LightningModule):
|
16 |
+
def __init__(self,
|
17 |
+
cover_key,
|
18 |
+
secret_key,
|
19 |
+
secret_len,
|
20 |
+
resolution,
|
21 |
+
secret_encoder_config,
|
22 |
+
secret_decoder_config,
|
23 |
+
loss_config,
|
24 |
+
noise_config='__none__',
|
25 |
+
ckpt_path="__none__",
|
26 |
+
use_ema=False
|
27 |
+
):
|
28 |
+
super().__init__()
|
29 |
+
self.cover_key = cover_key
|
30 |
+
self.secret_key = secret_key
|
31 |
+
secret_encoder_config.params.secret_len = secret_len
|
32 |
+
secret_decoder_config.params.secret_len = secret_len
|
33 |
+
secret_encoder_config.params.resolution = resolution
|
34 |
+
secret_decoder_config.params.resolution = 224
|
35 |
+
self.encoder = instantiate_from_config(secret_encoder_config)
|
36 |
+
self.decoder = instantiate_from_config(secret_decoder_config)
|
37 |
+
self.loss_layer = instantiate_from_config(loss_config)
|
38 |
+
if noise_config != '__none__':
|
39 |
+
print('Using noise')
|
40 |
+
self.noise = instantiate_from_config(noise_config)
|
41 |
+
|
42 |
+
self.use_ema = use_ema
|
43 |
+
if self.use_ema:
|
44 |
+
print('Using EMA')
|
45 |
+
self.encoder_ema = LitEma(self.encoder)
|
46 |
+
self.decoder_ema = LitEma(self.decoder)
|
47 |
+
print(f"Keeping EMAs of {len(list(self.encoder_ema.buffers()) + list(self.decoder_ema.buffers()))}.")
|
48 |
+
|
49 |
+
if ckpt_path != "__none__":
|
50 |
+
self.init_from_ckpt(ckpt_path, ignore_keys=[])
|
51 |
+
|
52 |
+
# early training phase
|
53 |
+
self.fixed_img = None
|
54 |
+
self.fixed_secret = None
|
55 |
+
self.register_buffer("fixed_input", torch.tensor(True))
|
56 |
+
self.crop = kornia.augmentation.CenterCrop((224, 224), cropping_mode="resample") # early training phase
|
57 |
+
|
58 |
+
def init_from_ckpt(self, path, ignore_keys=list()):
|
59 |
+
sd = torch.load(path, map_location="cpu")["state_dict"]
|
60 |
+
keys = list(sd.keys())
|
61 |
+
for k in keys:
|
62 |
+
for ik in ignore_keys:
|
63 |
+
if k.startswith(ik):
|
64 |
+
print("Deleting key {} from state_dict.".format(k))
|
65 |
+
del sd[k]
|
66 |
+
self.load_state_dict(sd, strict=False)
|
67 |
+
print(f"Restored from {path}")
|
68 |
+
|
69 |
+
@contextmanager
|
70 |
+
def ema_scope(self, context=None):
|
71 |
+
if self.use_ema:
|
72 |
+
self.encoder_ema.store(self.encoder.parameters())
|
73 |
+
self.decoder_ema.store(self.decoder.parameters())
|
74 |
+
self.encoder_ema.copy_to(self.encoder)
|
75 |
+
self.decoder_ema.copy_to(self.decoder)
|
76 |
+
if context is not None:
|
77 |
+
print(f"{context}: Switched to EMA weights")
|
78 |
+
try:
|
79 |
+
yield None
|
80 |
+
finally:
|
81 |
+
if self.use_ema:
|
82 |
+
self.encoder_ema.restore(self.encoder.parameters())
|
83 |
+
self.decoder_ema.restore(self.decoder.parameters())
|
84 |
+
if context is not None:
|
85 |
+
print(f"{context}: Restored training weights")
|
86 |
+
|
87 |
+
def on_train_batch_end(self, *args, **kwargs):
|
88 |
+
if self.use_ema:
|
89 |
+
self.encoder_ema(self.encoder)
|
90 |
+
self.decoder_ema(self.decoder)
|
91 |
+
|
92 |
+
@torch.no_grad()
|
93 |
+
def get_input(self, batch, bs=None):
|
94 |
+
image = batch[self.cover_key]
|
95 |
+
secret = batch[self.secret_key]
|
96 |
+
if bs is not None:
|
97 |
+
image = image[:bs]
|
98 |
+
secret = secret[:bs]
|
99 |
+
else:
|
100 |
+
bs = image.shape[0]
|
101 |
+
# encode image 1st stage
|
102 |
+
image = einops.rearrange(image, "b h w c -> b c h w").contiguous()
|
103 |
+
|
104 |
+
# check if using fixed input (early training phase)
|
105 |
+
# if self.training and self.fixed_input:
|
106 |
+
if self.fixed_input:
|
107 |
+
if self.fixed_img is None: # first iteration
|
108 |
+
print('[TRAINING] Warmup - using fixed input image for now!')
|
109 |
+
self.fixed_img = image.detach().clone()[:bs]
|
110 |
+
self.fixed_secret = secret.detach().clone()[:bs] # use for log_images with fixed_input option only
|
111 |
+
image = self.fixed_img
|
112 |
+
new_bs = min(secret.shape[0], image.shape[0])
|
113 |
+
image, secret = image[:new_bs], secret[:new_bs]
|
114 |
+
|
115 |
+
out = [image, secret]
|
116 |
+
return out
|
117 |
+
|
118 |
+
def forward(self, cover, secret):
|
119 |
+
# return a tuple (stego, residual)
|
120 |
+
enc_out = self.encoder(cover, secret)
|
121 |
+
if self.encoder.return_residual:
|
122 |
+
return cover + enc_out, enc_out
|
123 |
+
else:
|
124 |
+
return enc_out, enc_out - cover
|
125 |
+
|
126 |
+
def shared_step(self, batch):
|
127 |
+
x, s = self.get_input(batch)
|
128 |
+
stego, residual = self(x, s)
|
129 |
+
if hasattr(self, "noise") and self.noise.is_activated():
|
130 |
+
stego_noised = self.noise(stego, self.global_step, p=0.9)
|
131 |
+
else:
|
132 |
+
stego_noised = self.crop(stego)
|
133 |
+
stego_noised = torch.clamp(stego_noised, -1, 1)
|
134 |
+
spred = self.decoder(stego_noised)
|
135 |
+
|
136 |
+
loss, loss_dict = self.loss_layer(x, stego, None, s, spred, self.global_step)
|
137 |
+
bit_acc = loss_dict["bit_acc"]
|
138 |
+
|
139 |
+
bit_acc_ = bit_acc.item()
|
140 |
+
|
141 |
+
if (bit_acc_ > 0.98) and (not self.fixed_input) and self.noise.is_activated():
|
142 |
+
self.loss_layer.activate_ramp(self.global_step)
|
143 |
+
|
144 |
+
if (bit_acc_ > 0.95) and (not self.fixed_input): # ramp up image loss at late training stage
|
145 |
+
if hasattr(self, 'noise') and (not self.noise.is_activated()):
|
146 |
+
self.noise.activate(self.global_step)
|
147 |
+
|
148 |
+
if (bit_acc_ > 0.9) and self.fixed_input: # execute only once
|
149 |
+
print(f'[TRAINING] High bit acc ({bit_acc_}) achieved, switch to full image dataset training.')
|
150 |
+
self.fixed_input = ~self.fixed_input
|
151 |
+
return loss, loss_dict
|
152 |
+
|
153 |
+
def training_step(self, batch, batch_idx):
|
154 |
+
loss, loss_dict = self.shared_step(batch)
|
155 |
+
loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
|
156 |
+
self.log_dict(loss_dict, prog_bar=True,
|
157 |
+
logger=True, on_step=True, on_epoch=True)
|
158 |
+
|
159 |
+
self.log("global_step", self.global_step,
|
160 |
+
prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
161 |
+
# if self.use_scheduler:
|
162 |
+
# lr = self.optimizers().param_groups[0]['lr']
|
163 |
+
# self.log('lr_abs', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False)
|
164 |
+
|
165 |
+
return loss
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def validation_step(self, batch, batch_idx):
|
169 |
+
_, loss_dict_no_ema = self.shared_step(batch)
|
170 |
+
loss_dict_no_ema = {f"val/{key}": val for key, val in loss_dict_no_ema.items() if key != 'img_lw'}
|
171 |
+
with self.ema_scope():
|
172 |
+
_, loss_dict_ema = self.shared_step(batch)
|
173 |
+
loss_dict_ema = {'val/' + key + '_ema': loss_dict_ema[key] for key in loss_dict_ema}
|
174 |
+
self.log_dict(loss_dict_no_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
175 |
+
self.log_dict(loss_dict_ema, prog_bar=False, logger=True, on_step=False, on_epoch=True)
|
176 |
+
|
177 |
+
@torch.no_grad()
|
178 |
+
def log_images(self, batch, fixed_input=False, **kwargs):
|
179 |
+
log = dict()
|
180 |
+
if fixed_input and self.fixed_img is not None:
|
181 |
+
x, s = self.fixed_img, self.fixed_secret
|
182 |
+
else:
|
183 |
+
x, s = self.get_input(batch)
|
184 |
+
stego, residual = self(x, s)
|
185 |
+
if hasattr(self, 'noise') and self.noise.is_activated():
|
186 |
+
img_noise = self.noise(stego, self.global_step, p=1.0)
|
187 |
+
log['noised'] = img_noise
|
188 |
+
log['input'] = x
|
189 |
+
log['stego'] = stego
|
190 |
+
log['residual'] = (residual - residual.min()) / (residual.max() - residual.min() + 1e-8)*2 - 1
|
191 |
+
return log
|
192 |
+
|
193 |
+
def configure_optimizers(self):
|
194 |
+
lr = self.learning_rate
|
195 |
+
params = list(self.encoder.parameters()) + list(self.decoder.parameters())
|
196 |
+
optimizer = torch.optim.AdamW(params, lr=lr)
|
197 |
+
return optimizer
|
198 |
+
|
199 |
+
|
200 |
+
|
201 |
+
|
202 |
+
class SecretEncoder(nn.Module):
|
203 |
+
def __init__(self, resolution=256, secret_len=100, return_residual=False, act='tanh') -> None:
|
204 |
+
super().__init__()
|
205 |
+
self.secret_len = secret_len
|
206 |
+
self.return_residual = return_residual
|
207 |
+
self.act_fn = lambda x: torch.tanh(x) if act == 'tanh' else thf.sigmoid(x) * 2.0 -1.0
|
208 |
+
self.secret_dense = nn.Linear(secret_len, 16*16*3)
|
209 |
+
log_resolution = int(math.log(resolution, 2))
|
210 |
+
assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
|
211 |
+
self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))
|
212 |
+
self.conv1 = nn.Conv2d(2 * 3, 32, 3, 1, 1)
|
213 |
+
self.conv2 = nn.Conv2d(32, 32, 3, 2, 1)
|
214 |
+
self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
|
215 |
+
self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
|
216 |
+
self.conv5 = nn.Conv2d(128, 256, 3, 2, 1)
|
217 |
+
self.pad6 = nn.ZeroPad2d((0, 1, 0, 1))
|
218 |
+
self.up6 = nn.Conv2d(256, 128, 2, 1)
|
219 |
+
self.upsample6 = nn.Upsample(scale_factor=(2, 2))
|
220 |
+
self.conv6 = nn.Conv2d(128 + 128, 128, 3, 1, 1)
|
221 |
+
self.pad7 = nn.ZeroPad2d((0, 1, 0, 1))
|
222 |
+
self.up7 = nn.Conv2d(128, 64, 2, 1)
|
223 |
+
self.upsample7 = nn.Upsample(scale_factor=(2, 2))
|
224 |
+
self.conv7 = nn.Conv2d(64 + 64, 64, 3, 1, 1)
|
225 |
+
self.pad8 = nn.ZeroPad2d((0, 1, 0, 1))
|
226 |
+
self.up8 = nn.Conv2d(64, 32, 2, 1)
|
227 |
+
self.upsample8 = nn.Upsample(scale_factor=(2, 2))
|
228 |
+
self.conv8 = nn.Conv2d(32 + 32, 32, 3, 1, 1)
|
229 |
+
self.pad9 = nn.ZeroPad2d((0, 1, 0, 1))
|
230 |
+
self.up9 = nn.Conv2d(32, 32, 2, 1)
|
231 |
+
self.upsample9 = nn.Upsample(scale_factor=(2, 2))
|
232 |
+
self.conv9 = nn.Conv2d(32 + 32 + 2 * 3, 32, 3, 1, 1)
|
233 |
+
self.conv10 = nn.Conv2d(32, 32, 3, 1, 1)
|
234 |
+
self.residual = nn.Conv2d(32, 3, 1)
|
235 |
+
|
236 |
+
def forward(self, image, secret):
|
237 |
+
fingerprint = thf.relu(self.secret_dense(secret))
|
238 |
+
fingerprint = fingerprint.view((-1, 3, 16, 16))
|
239 |
+
fingerprint_enlarged = self.secret_upsample(fingerprint)
|
240 |
+
# try:
|
241 |
+
inputs = torch.cat([fingerprint_enlarged, image], dim=1)
|
242 |
+
# except:
|
243 |
+
# print(fingerprint_enlarged.shape, image.shape, fingerprint.shape)
|
244 |
+
# import pdb; pdb.set_trace()
|
245 |
+
conv1 = thf.relu(self.conv1(inputs))
|
246 |
+
conv2 = thf.relu(self.conv2(conv1))
|
247 |
+
conv3 = thf.relu(self.conv3(conv2))
|
248 |
+
conv4 = thf.relu(self.conv4(conv3))
|
249 |
+
conv5 = thf.relu(self.conv5(conv4))
|
250 |
+
up6 = thf.relu(self.up6(self.pad6(self.upsample6(conv5))))
|
251 |
+
merge6 = torch.cat([conv4, up6], dim=1)
|
252 |
+
conv6 = thf.relu(self.conv6(merge6))
|
253 |
+
up7 = thf.relu(self.up7(self.pad7(self.upsample7(conv6))))
|
254 |
+
merge7 = torch.cat([conv3, up7], dim=1)
|
255 |
+
conv7 = thf.relu(self.conv7(merge7))
|
256 |
+
up8 = thf.relu(self.up8(self.pad8(self.upsample8(conv7))))
|
257 |
+
merge8 = torch.cat([conv2, up8], dim=1)
|
258 |
+
conv8 = thf.relu(self.conv8(merge8))
|
259 |
+
up9 = thf.relu(self.up9(self.pad9(self.upsample9(conv8))))
|
260 |
+
merge9 = torch.cat([conv1, up9, inputs], dim=1)
|
261 |
+
conv9 = thf.relu(self.conv9(merge9))
|
262 |
+
conv10 = thf.relu(self.conv10(conv9))
|
263 |
+
residual = self.residual(conv10)
|
264 |
+
residual = self.act_fn(residual)
|
265 |
+
return residual
|
266 |
+
|
267 |
+
|
268 |
+
class SecretEncoder1(nn.Module):
|
269 |
+
def __init__(self, resolution=256, secret_len=100) -> None:
|
270 |
+
pass
|
271 |
+
|
272 |
+
class SecretDecoder(nn.Module):
|
273 |
+
def __init__(self, arch='resnet18', resolution=224, secret_len=100):
|
274 |
+
super().__init__()
|
275 |
+
self.resolution = resolution
|
276 |
+
self.arch = arch
|
277 |
+
if arch == 'resnet18':
|
278 |
+
self.decoder = torchvision.models.resnet18(pretrained=True, progress=False)
|
279 |
+
self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
|
280 |
+
elif arch == 'resnet50':
|
281 |
+
self.decoder = torchvision.models.resnet50(pretrained=True, progress=False)
|
282 |
+
self.decoder.fc = nn.Linear(self.decoder.fc.in_features, secret_len)
|
283 |
+
elif arch == 'simple':
|
284 |
+
self.decoder = SimpleCNN(resolution, secret_len)
|
285 |
+
else:
|
286 |
+
raise ValueError('Unknown architecture')
|
287 |
+
|
288 |
+
def forward(self, image):
|
289 |
+
if self.arch in ['resnet50', 'resnet18'] and image.shape[-1] > self.resolution:
|
290 |
+
image = thf.interpolate(image, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False)
|
291 |
+
x = self.decoder(image)
|
292 |
+
return x
|
293 |
+
|
294 |
+
|
295 |
+
class SimpleCNN(nn.Module):
|
296 |
+
def __init__(self, resolution=224, secret_len=100):
|
297 |
+
super().__init__()
|
298 |
+
self.resolution = resolution
|
299 |
+
self.IMAGE_CHANNELS = 3
|
300 |
+
self.decoder = nn.Sequential(
|
301 |
+
nn.Conv2d(self.IMAGE_CHANNELS, 32, (3, 3), 2, 1), # resolution / 2
|
302 |
+
nn.ReLU(),
|
303 |
+
nn.Conv2d(32, 32, 3, 1, 1),
|
304 |
+
nn.ReLU(),
|
305 |
+
nn.Conv2d(32, 64, 3, 2, 1), # resolution / 4
|
306 |
+
nn.ReLU(),
|
307 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
308 |
+
nn.ReLU(),
|
309 |
+
nn.Conv2d(64, 64, 3, 2, 1), # resolution / 8
|
310 |
+
nn.ReLU(),
|
311 |
+
nn.Conv2d(64, 128, 3, 2, 1), # resolution / 16
|
312 |
+
nn.ReLU(),
|
313 |
+
nn.Conv2d(128, 128, (3, 3), 2, 1), # resolution / 32
|
314 |
+
nn.ReLU(),
|
315 |
+
)
|
316 |
+
self.dense = nn.Sequential(
|
317 |
+
nn.Linear(resolution * resolution * 128 // 32 // 32, 512),
|
318 |
+
nn.ReLU(),
|
319 |
+
nn.Linear(512, secret_len),
|
320 |
+
)
|
321 |
+
|
322 |
+
def forward(self, image):
|
323 |
+
x = self.decoder(image)
|
324 |
+
x = x.view(-1, self.resolution * self.resolution * 128 // 32 // 32)
|
325 |
+
return self.dense(x)
|
flae/munit.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
|
3 |
+
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
|
4 |
+
"""
|
5 |
+
from torch import nn
|
6 |
+
from torch.autograd import Variable
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
try:
|
10 |
+
from itertools import izip as zip
|
11 |
+
except ImportError: # will be 3.x series
|
12 |
+
pass
|
13 |
+
|
14 |
+
##################################################################################
|
15 |
+
# Discriminator
|
16 |
+
##################################################################################
|
17 |
+
|
18 |
+
class MsImageDis(nn.Module):
|
19 |
+
# Multi-scale discriminator architecture
|
20 |
+
def __init__(self, input_dim, params):
|
21 |
+
super(MsImageDis, self).__init__()
|
22 |
+
self.n_layer = params['n_layer']
|
23 |
+
self.gan_type = params['gan_type']
|
24 |
+
self.dim = params['dim']
|
25 |
+
self.norm = params['norm']
|
26 |
+
self.activ = params['activ']
|
27 |
+
self.num_scales = params['num_scales']
|
28 |
+
self.pad_type = params['pad_type']
|
29 |
+
self.input_dim = input_dim
|
30 |
+
self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
|
31 |
+
self.cnns = nn.ModuleList()
|
32 |
+
for _ in range(self.num_scales):
|
33 |
+
self.cnns.append(self._make_net())
|
34 |
+
|
35 |
+
def _make_net(self):
|
36 |
+
dim = self.dim
|
37 |
+
cnn_x = []
|
38 |
+
cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)]
|
39 |
+
for i in range(self.n_layer - 1):
|
40 |
+
cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)]
|
41 |
+
dim *= 2
|
42 |
+
cnn_x += [nn.Conv2d(dim, 1, 1, 1, 0)]
|
43 |
+
cnn_x = nn.Sequential(*cnn_x)
|
44 |
+
return cnn_x
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
outputs = []
|
48 |
+
for model in self.cnns:
|
49 |
+
outputs.append(model(x))
|
50 |
+
x = self.downsample(x)
|
51 |
+
return outputs
|
52 |
+
|
53 |
+
def calc_dis_loss(self, input_fake, input_real):
|
54 |
+
# calculate the loss to train D
|
55 |
+
outs0 = self.forward(input_fake)
|
56 |
+
outs1 = self.forward(input_real)
|
57 |
+
loss = 0
|
58 |
+
|
59 |
+
for it, (out0, out1) in enumerate(zip(outs0, outs1)):
|
60 |
+
if self.gan_type == 'lsgan':
|
61 |
+
loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2)
|
62 |
+
elif self.gan_type == 'nsgan':
|
63 |
+
all0 = Variable(torch.zeros_like(out0.data).cuda(), requires_grad=False)
|
64 |
+
all1 = Variable(torch.ones_like(out1.data).cuda(), requires_grad=False)
|
65 |
+
loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) +
|
66 |
+
F.binary_cross_entropy(F.sigmoid(out1), all1))
|
67 |
+
else:
|
68 |
+
assert 0, "Unsupported GAN type: {}".format(self.gan_type)
|
69 |
+
return loss
|
70 |
+
|
71 |
+
def calc_gen_loss(self, input_fake):
|
72 |
+
# calculate the loss to train G
|
73 |
+
outs0 = self.forward(input_fake)
|
74 |
+
loss = 0
|
75 |
+
for it, (out0) in enumerate(outs0):
|
76 |
+
if self.gan_type == 'lsgan':
|
77 |
+
loss += torch.mean((out0 - 1)**2) # LSGAN
|
78 |
+
elif self.gan_type == 'nsgan':
|
79 |
+
all1 = Variable(torch.ones_like(out0.data).cuda(), requires_grad=False)
|
80 |
+
loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1))
|
81 |
+
else:
|
82 |
+
assert 0, "Unsupported GAN type: {}".format(self.gan_type)
|
83 |
+
return loss
|
84 |
+
|
85 |
+
##################################################################################
|
86 |
+
# Generator
|
87 |
+
##################################################################################
|
88 |
+
|
89 |
+
class AdaINGen(nn.Module):
|
90 |
+
# AdaIN auto-encoder architecture
|
91 |
+
def __init__(self, input_dim, params):
|
92 |
+
super(AdaINGen, self).__init__()
|
93 |
+
dim = params['dim']
|
94 |
+
style_dim = params['style_dim']
|
95 |
+
n_downsample = params['n_downsample']
|
96 |
+
n_res = params['n_res']
|
97 |
+
activ = params['activ']
|
98 |
+
pad_type = params['pad_type']
|
99 |
+
mlp_dim = params['mlp_dim']
|
100 |
+
|
101 |
+
# style encoder
|
102 |
+
self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type)
|
103 |
+
|
104 |
+
# content encoder
|
105 |
+
self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
|
106 |
+
self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, input_dim, res_norm='adain', activ=activ, pad_type=pad_type)
|
107 |
+
|
108 |
+
# MLP to generate AdaIN parameters
|
109 |
+
self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ)
|
110 |
+
|
111 |
+
def forward(self, images):
|
112 |
+
# reconstruct an image
|
113 |
+
content, style_fake = self.encode(images)
|
114 |
+
images_recon = self.decode(content, style_fake)
|
115 |
+
return images_recon
|
116 |
+
|
117 |
+
def encode(self, images):
|
118 |
+
# encode an image to its content and style codes
|
119 |
+
style_fake = self.enc_style(images)
|
120 |
+
content = self.enc_content(images)
|
121 |
+
return content, style_fake
|
122 |
+
|
123 |
+
def decode(self, content, style):
|
124 |
+
# decode content and style codes to an image
|
125 |
+
adain_params = self.mlp(style)
|
126 |
+
self.assign_adain_params(adain_params, self.dec)
|
127 |
+
images = self.dec(content)
|
128 |
+
return images
|
129 |
+
|
130 |
+
def assign_adain_params(self, adain_params, model):
|
131 |
+
# assign the adain_params to the AdaIN layers in model
|
132 |
+
for m in model.modules():
|
133 |
+
if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
|
134 |
+
mean = adain_params[:, :m.num_features]
|
135 |
+
std = adain_params[:, m.num_features:2*m.num_features]
|
136 |
+
m.bias = mean.contiguous().view(-1)
|
137 |
+
m.weight = std.contiguous().view(-1)
|
138 |
+
if adain_params.size(1) > 2*m.num_features:
|
139 |
+
adain_params = adain_params[:, 2*m.num_features:]
|
140 |
+
|
141 |
+
def get_num_adain_params(self, model):
|
142 |
+
# return the number of AdaIN parameters needed by the model
|
143 |
+
num_adain_params = 0
|
144 |
+
for m in model.modules():
|
145 |
+
if m.__class__.__name__ == "AdaptiveInstanceNorm2d":
|
146 |
+
num_adain_params += 2*m.num_features
|
147 |
+
return num_adain_params
|
148 |
+
|
149 |
+
|
150 |
+
class VAEGen(nn.Module):
|
151 |
+
# VAE architecture
|
152 |
+
def __init__(self, input_dim, params):
|
153 |
+
super(VAEGen, self).__init__()
|
154 |
+
dim = params['dim']
|
155 |
+
n_downsample = params['n_downsample']
|
156 |
+
n_res = params['n_res']
|
157 |
+
activ = params['activ']
|
158 |
+
pad_type = params['pad_type']
|
159 |
+
|
160 |
+
# content encoder
|
161 |
+
self.enc = ContentEncoder(n_downsample, n_res, input_dim, dim, 'in', activ, pad_type=pad_type)
|
162 |
+
self.dec = Decoder(n_downsample, n_res, self.enc.output_dim, input_dim, res_norm='in', activ=activ, pad_type=pad_type)
|
163 |
+
|
164 |
+
def forward(self, images):
|
165 |
+
# This is a reduced VAE implementation where we assume the outputs are multivariate Gaussian distribution with mean = hiddens and std_dev = all ones.
|
166 |
+
hiddens = self.encode(images)
|
167 |
+
if self.training == True:
|
168 |
+
noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
|
169 |
+
images_recon = self.decode(hiddens + noise)
|
170 |
+
else:
|
171 |
+
images_recon = self.decode(hiddens)
|
172 |
+
return images_recon, hiddens
|
173 |
+
|
174 |
+
def encode(self, images):
|
175 |
+
hiddens = self.enc(images)
|
176 |
+
noise = Variable(torch.randn(hiddens.size()).cuda(hiddens.data.get_device()))
|
177 |
+
return hiddens, noise
|
178 |
+
|
179 |
+
def decode(self, hiddens):
|
180 |
+
images = self.dec(hiddens)
|
181 |
+
return images
|
182 |
+
|
183 |
+
|
184 |
+
##################################################################################
|
185 |
+
# Encoder and Decoders
|
186 |
+
##################################################################################
|
187 |
+
|
188 |
+
class StyleEncoder(nn.Module):
|
189 |
+
def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
|
190 |
+
super(StyleEncoder, self).__init__()
|
191 |
+
self.model = []
|
192 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
|
193 |
+
for i in range(2):
|
194 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
|
195 |
+
dim *= 2
|
196 |
+
for i in range(n_downsample - 2):
|
197 |
+
self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
|
198 |
+
self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling
|
199 |
+
self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
|
200 |
+
self.model = nn.Sequential(*self.model)
|
201 |
+
self.output_dim = dim
|
202 |
+
|
203 |
+
def forward(self, x):
|
204 |
+
return self.model(x)
|
205 |
+
|
206 |
+
class ContentEncoder(nn.Module):
|
207 |
+
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
|
208 |
+
super(ContentEncoder, self).__init__()
|
209 |
+
self.model = []
|
210 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
|
211 |
+
# downsampling blocks
|
212 |
+
for i in range(n_downsample):
|
213 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
|
214 |
+
dim *= 2
|
215 |
+
# residual blocks
|
216 |
+
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
|
217 |
+
self.model = nn.Sequential(*self.model)
|
218 |
+
self.output_dim = dim
|
219 |
+
|
220 |
+
def forward(self, x):
|
221 |
+
return self.model(x)
|
222 |
+
|
223 |
+
class Decoder(nn.Module):
|
224 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
|
225 |
+
super(Decoder, self).__init__()
|
226 |
+
|
227 |
+
self.model = []
|
228 |
+
# AdaIN residual blocks
|
229 |
+
self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)]
|
230 |
+
# upsampling blocks
|
231 |
+
for i in range(n_upsample):
|
232 |
+
self.model += [nn.Upsample(scale_factor=2),
|
233 |
+
Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
|
234 |
+
dim //= 2
|
235 |
+
# use reflection padding in the last conv layer
|
236 |
+
self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
|
237 |
+
self.model = nn.Sequential(*self.model)
|
238 |
+
|
239 |
+
def forward(self, x):
|
240 |
+
return self.model(x)
|
241 |
+
|
242 |
+
##################################################################################
|
243 |
+
# Sequential Models
|
244 |
+
##################################################################################
|
245 |
+
class ResBlocks(nn.Module):
|
246 |
+
def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
|
247 |
+
super(ResBlocks, self).__init__()
|
248 |
+
self.model = []
|
249 |
+
for i in range(num_blocks):
|
250 |
+
self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
|
251 |
+
self.model = nn.Sequential(*self.model)
|
252 |
+
|
253 |
+
def forward(self, x):
|
254 |
+
return self.model(x)
|
255 |
+
|
256 |
+
class MLP(nn.Module):
|
257 |
+
def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
|
258 |
+
|
259 |
+
super(MLP, self).__init__()
|
260 |
+
self.model = []
|
261 |
+
self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
|
262 |
+
for i in range(n_blk - 2):
|
263 |
+
self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
|
264 |
+
self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
|
265 |
+
self.model = nn.Sequential(*self.model)
|
266 |
+
|
267 |
+
def forward(self, x):
|
268 |
+
return self.model(x.view(x.size(0), -1))
|
269 |
+
|
270 |
+
##################################################################################
|
271 |
+
# Basic Blocks
|
272 |
+
##################################################################################
|
273 |
+
class ResBlock(nn.Module):
|
274 |
+
def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
|
275 |
+
super(ResBlock, self).__init__()
|
276 |
+
|
277 |
+
model = []
|
278 |
+
model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
|
279 |
+
model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
|
280 |
+
self.model = nn.Sequential(*model)
|
281 |
+
|
282 |
+
def forward(self, x):
|
283 |
+
residual = x
|
284 |
+
out = self.model(x)
|
285 |
+
out += residual
|
286 |
+
return out
|
287 |
+
|
288 |
+
class Conv2dBlock(nn.Module):
|
289 |
+
def __init__(self, input_dim ,output_dim, kernel_size, stride,
|
290 |
+
padding=0, norm='none', activation='relu', pad_type='zero'):
|
291 |
+
super(Conv2dBlock, self).__init__()
|
292 |
+
self.use_bias = True
|
293 |
+
# initialize padding
|
294 |
+
if pad_type == 'reflect':
|
295 |
+
self.pad = nn.ReflectionPad2d(padding)
|
296 |
+
elif pad_type == 'replicate':
|
297 |
+
self.pad = nn.ReplicationPad2d(padding)
|
298 |
+
elif pad_type == 'zero':
|
299 |
+
self.pad = nn.ZeroPad2d(padding)
|
300 |
+
else:
|
301 |
+
assert 0, "Unsupported padding type: {}".format(pad_type)
|
302 |
+
|
303 |
+
# initialize normalization
|
304 |
+
norm_dim = output_dim
|
305 |
+
if norm == 'bn':
|
306 |
+
self.norm = nn.BatchNorm2d(norm_dim)
|
307 |
+
elif norm == 'in':
|
308 |
+
#self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
|
309 |
+
self.norm = nn.InstanceNorm2d(norm_dim)
|
310 |
+
elif norm == 'ln':
|
311 |
+
self.norm = LayerNorm(norm_dim)
|
312 |
+
elif norm == 'adain':
|
313 |
+
self.norm = AdaptiveInstanceNorm2d(norm_dim)
|
314 |
+
elif norm == 'none' or norm == 'sn':
|
315 |
+
self.norm = None
|
316 |
+
else:
|
317 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
318 |
+
|
319 |
+
# initialize activation
|
320 |
+
if activation == 'relu':
|
321 |
+
self.activation = nn.ReLU(inplace=True)
|
322 |
+
elif activation == 'lrelu':
|
323 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
324 |
+
elif activation == 'prelu':
|
325 |
+
self.activation = nn.PReLU()
|
326 |
+
elif activation == 'selu':
|
327 |
+
self.activation = nn.SELU(inplace=True)
|
328 |
+
elif activation == 'tanh':
|
329 |
+
self.activation = nn.Tanh()
|
330 |
+
elif activation == 'none':
|
331 |
+
self.activation = None
|
332 |
+
else:
|
333 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
334 |
+
|
335 |
+
# initialize convolution
|
336 |
+
if norm == 'sn':
|
337 |
+
self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
|
338 |
+
else:
|
339 |
+
self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
|
340 |
+
|
341 |
+
def forward(self, x):
|
342 |
+
x = self.conv(self.pad(x))
|
343 |
+
if self.norm:
|
344 |
+
x = self.norm(x)
|
345 |
+
if self.activation:
|
346 |
+
x = self.activation(x)
|
347 |
+
return x
|
348 |
+
|
349 |
+
class LinearBlock(nn.Module):
|
350 |
+
def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
|
351 |
+
super(LinearBlock, self).__init__()
|
352 |
+
use_bias = True
|
353 |
+
# initialize fully connected layer
|
354 |
+
if norm == 'sn':
|
355 |
+
self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
|
356 |
+
else:
|
357 |
+
self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
|
358 |
+
|
359 |
+
# initialize normalization
|
360 |
+
norm_dim = output_dim
|
361 |
+
if norm == 'bn':
|
362 |
+
self.norm = nn.BatchNorm1d(norm_dim)
|
363 |
+
elif norm == 'in':
|
364 |
+
self.norm = nn.InstanceNorm1d(norm_dim)
|
365 |
+
elif norm == 'ln':
|
366 |
+
self.norm = LayerNorm(norm_dim)
|
367 |
+
elif norm == 'none' or norm == 'sn':
|
368 |
+
self.norm = None
|
369 |
+
else:
|
370 |
+
assert 0, "Unsupported normalization: {}".format(norm)
|
371 |
+
|
372 |
+
# initialize activation
|
373 |
+
if activation == 'relu':
|
374 |
+
self.activation = nn.ReLU(inplace=True)
|
375 |
+
elif activation == 'lrelu':
|
376 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
377 |
+
elif activation == 'prelu':
|
378 |
+
self.activation = nn.PReLU()
|
379 |
+
elif activation == 'selu':
|
380 |
+
self.activation = nn.SELU(inplace=True)
|
381 |
+
elif activation == 'tanh':
|
382 |
+
self.activation = nn.Tanh()
|
383 |
+
elif activation == 'none':
|
384 |
+
self.activation = None
|
385 |
+
else:
|
386 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
387 |
+
|
388 |
+
def forward(self, x):
|
389 |
+
out = self.fc(x)
|
390 |
+
if self.norm:
|
391 |
+
out = self.norm(out)
|
392 |
+
if self.activation:
|
393 |
+
out = self.activation(out)
|
394 |
+
return out
|
395 |
+
|
396 |
+
##################################################################################
|
397 |
+
# VGG network definition
|
398 |
+
##################################################################################
|
399 |
+
class Vgg16(nn.Module):
|
400 |
+
def __init__(self):
|
401 |
+
super(Vgg16, self).__init__()
|
402 |
+
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
403 |
+
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
404 |
+
|
405 |
+
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
406 |
+
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
407 |
+
|
408 |
+
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
409 |
+
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
410 |
+
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
411 |
+
|
412 |
+
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
413 |
+
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
414 |
+
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
415 |
+
|
416 |
+
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
417 |
+
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
418 |
+
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
419 |
+
|
420 |
+
def forward(self, X):
|
421 |
+
h = F.relu(self.conv1_1(X), inplace=True)
|
422 |
+
h = F.relu(self.conv1_2(h), inplace=True)
|
423 |
+
# relu1_2 = h
|
424 |
+
h = F.max_pool2d(h, kernel_size=2, stride=2)
|
425 |
+
|
426 |
+
h = F.relu(self.conv2_1(h), inplace=True)
|
427 |
+
h = F.relu(self.conv2_2(h), inplace=True)
|
428 |
+
# relu2_2 = h
|
429 |
+
h = F.max_pool2d(h, kernel_size=2, stride=2)
|
430 |
+
|
431 |
+
h = F.relu(self.conv3_1(h), inplace=True)
|
432 |
+
h = F.relu(self.conv3_2(h), inplace=True)
|
433 |
+
h = F.relu(self.conv3_3(h), inplace=True)
|
434 |
+
# relu3_3 = h
|
435 |
+
h = F.max_pool2d(h, kernel_size=2, stride=2)
|
436 |
+
|
437 |
+
h = F.relu(self.conv4_1(h), inplace=True)
|
438 |
+
h = F.relu(self.conv4_2(h), inplace=True)
|
439 |
+
h = F.relu(self.conv4_3(h), inplace=True)
|
440 |
+
# relu4_3 = h
|
441 |
+
|
442 |
+
h = F.relu(self.conv5_1(h), inplace=True)
|
443 |
+
h = F.relu(self.conv5_2(h), inplace=True)
|
444 |
+
h = F.relu(self.conv5_3(h), inplace=True)
|
445 |
+
relu5_3 = h
|
446 |
+
|
447 |
+
return relu5_3
|
448 |
+
# return [relu1_2, relu2_2, relu3_3, relu4_3]
|
449 |
+
|
450 |
+
##################################################################################
|
451 |
+
# Normalization layers
|
452 |
+
##################################################################################
|
453 |
+
class AdaptiveInstanceNorm2d(nn.Module):
|
454 |
+
def __init__(self, num_features, eps=1e-5, momentum=0.1):
|
455 |
+
super(AdaptiveInstanceNorm2d, self).__init__()
|
456 |
+
self.num_features = num_features
|
457 |
+
self.eps = eps
|
458 |
+
self.momentum = momentum
|
459 |
+
# weight and bias are dynamically assigned
|
460 |
+
self.weight = None
|
461 |
+
self.bias = None
|
462 |
+
# just dummy buffers, not used
|
463 |
+
self.register_buffer('running_mean', torch.zeros(num_features))
|
464 |
+
self.register_buffer('running_var', torch.ones(num_features))
|
465 |
+
|
466 |
+
def forward(self, x):
|
467 |
+
assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!"
|
468 |
+
b, c = x.size(0), x.size(1)
|
469 |
+
running_mean = self.running_mean.repeat(b)
|
470 |
+
running_var = self.running_var.repeat(b)
|
471 |
+
|
472 |
+
# Apply instance norm
|
473 |
+
x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
|
474 |
+
|
475 |
+
out = F.batch_norm(
|
476 |
+
x_reshaped, running_mean, running_var, self.weight, self.bias,
|
477 |
+
True, self.momentum, self.eps)
|
478 |
+
|
479 |
+
return out.view(b, c, *x.size()[2:])
|
480 |
+
|
481 |
+
def __repr__(self):
|
482 |
+
return self.__class__.__name__ + '(' + str(self.num_features) + ')'
|
483 |
+
|
484 |
+
|
485 |
+
class LayerNorm(nn.Module):
|
486 |
+
def __init__(self, num_features, eps=1e-5, affine=True):
|
487 |
+
super(LayerNorm, self).__init__()
|
488 |
+
self.num_features = num_features
|
489 |
+
self.affine = affine
|
490 |
+
self.eps = eps
|
491 |
+
|
492 |
+
if self.affine:
|
493 |
+
self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
|
494 |
+
self.beta = nn.Parameter(torch.zeros(num_features))
|
495 |
+
|
496 |
+
def forward(self, x):
|
497 |
+
shape = [-1] + [1] * (x.dim() - 1)
|
498 |
+
# print(x.size())
|
499 |
+
if x.size(0) == 1:
|
500 |
+
# These two lines run much faster in pytorch 0.4 than the two lines listed below.
|
501 |
+
mean = x.view(-1).mean().view(*shape)
|
502 |
+
std = x.view(-1).std().view(*shape)
|
503 |
+
else:
|
504 |
+
mean = x.view(x.size(0), -1).mean(1).view(*shape)
|
505 |
+
std = x.view(x.size(0), -1).std(1).view(*shape)
|
506 |
+
|
507 |
+
x = (x - mean) / (std + self.eps)
|
508 |
+
|
509 |
+
if self.affine:
|
510 |
+
shape = [1, -1] + [1] * (x.dim() - 2)
|
511 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
512 |
+
return x
|
513 |
+
|
514 |
+
def l2normalize(v, eps=1e-12):
|
515 |
+
return v / (v.norm() + eps)
|
516 |
+
|
517 |
+
|
518 |
+
class SpectralNorm(nn.Module):
|
519 |
+
"""
|
520 |
+
Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
|
521 |
+
and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
|
522 |
+
"""
|
523 |
+
def __init__(self, module, name='weight', power_iterations=1):
|
524 |
+
super(SpectralNorm, self).__init__()
|
525 |
+
self.module = module
|
526 |
+
self.name = name
|
527 |
+
self.power_iterations = power_iterations
|
528 |
+
if not self._made_params():
|
529 |
+
self._make_params()
|
530 |
+
|
531 |
+
def _update_u_v(self):
|
532 |
+
u = getattr(self.module, self.name + "_u")
|
533 |
+
v = getattr(self.module, self.name + "_v")
|
534 |
+
w = getattr(self.module, self.name + "_bar")
|
535 |
+
|
536 |
+
height = w.data.shape[0]
|
537 |
+
for _ in range(self.power_iterations):
|
538 |
+
v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data))
|
539 |
+
u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data))
|
540 |
+
|
541 |
+
# sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
|
542 |
+
sigma = u.dot(w.view(height, -1).mv(v))
|
543 |
+
setattr(self.module, self.name, w / sigma.expand_as(w))
|
544 |
+
|
545 |
+
def _made_params(self):
|
546 |
+
try:
|
547 |
+
u = getattr(self.module, self.name + "_u")
|
548 |
+
v = getattr(self.module, self.name + "_v")
|
549 |
+
w = getattr(self.module, self.name + "_bar")
|
550 |
+
return True
|
551 |
+
except AttributeError:
|
552 |
+
return False
|
553 |
+
|
554 |
+
|
555 |
+
def _make_params(self):
|
556 |
+
w = getattr(self.module, self.name)
|
557 |
+
|
558 |
+
height = w.data.shape[0]
|
559 |
+
width = w.view(height, -1).data.shape[1]
|
560 |
+
|
561 |
+
u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
|
562 |
+
v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
|
563 |
+
u.data = l2normalize(u.data)
|
564 |
+
v.data = l2normalize(v.data)
|
565 |
+
w_bar = nn.Parameter(w.data)
|
566 |
+
|
567 |
+
del self.module._parameters[self.name]
|
568 |
+
|
569 |
+
self.module.register_parameter(self.name + "_u", u)
|
570 |
+
self.module.register_parameter(self.name + "_v", v)
|
571 |
+
self.module.register_parameter(self.name + "_bar", w_bar)
|
572 |
+
|
573 |
+
|
574 |
+
def forward(self, *args):
|
575 |
+
self._update_u_v()
|
576 |
+
return self.module.forward(*args)
|
flae/unet.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from torch import nn
|
2 |
+
from torch.autograd import Variable
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from .munit import ResBlocks, Conv2dBlock
|
6 |
+
import math
|
7 |
+
|
8 |
+
|
9 |
+
class Unet(nn.Module):
|
10 |
+
def __init__(self, resolution=256, secret_len=100, return_residual=False) -> None:
|
11 |
+
super().__init__()
|
12 |
+
self.secret_len = secret_len
|
13 |
+
self.return_residual = return_residual
|
14 |
+
self.secret_dense = nn.Linear(secret_len, 16*16*3)
|
15 |
+
log_resolution = int(math.log(resolution, 2))
|
16 |
+
assert resolution == 2 ** log_resolution, f"Image resolution must be a power of 2, got {resolution}."
|
17 |
+
self.secret_upsample = nn.Upsample(scale_factor=(2**(log_resolution-4), 2**(log_resolution-4)))
|
18 |
+
|
19 |
+
self.enc = Encoder(2, 4, 6, 64, 'bn' , 'relu', 'reflect')
|
20 |
+
self.dec = Decoder(2, 4, self.enc.output_dim, 3, 'bn', 'relu', 'reflect')
|
21 |
+
|
22 |
+
def forward(self, image, secret):
|
23 |
+
# import pdb; pdb.set_trace()
|
24 |
+
fingerprint = F.relu(self.secret_dense(secret))
|
25 |
+
fingerprint = fingerprint.view((-1, 3, 16, 16))
|
26 |
+
fingerprint_enlarged = self.secret_upsample(fingerprint)
|
27 |
+
inputs = torch.cat([fingerprint_enlarged, image], dim=1)
|
28 |
+
emb = self.enc(inputs)
|
29 |
+
# import pdb; pdb.set_trace()
|
30 |
+
out = self.dec(emb)
|
31 |
+
return out
|
32 |
+
|
33 |
+
class Encoder(nn.Module):
|
34 |
+
def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type):
|
35 |
+
super().__init__()
|
36 |
+
self.model = []
|
37 |
+
self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)]
|
38 |
+
# downsampling blocks
|
39 |
+
for i in range(n_downsample):
|
40 |
+
self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)]
|
41 |
+
dim *= 2
|
42 |
+
# residual blocks
|
43 |
+
self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)]
|
44 |
+
# self.model = nn.(*self.model)
|
45 |
+
self.model = nn.ModuleList(self.model)
|
46 |
+
self.output_dim = dim
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
out = []
|
50 |
+
for block in self.model:
|
51 |
+
x = block(x)
|
52 |
+
out.append(x)
|
53 |
+
# print(x.shape)
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
class Decoder(nn.Module):
|
58 |
+
def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'):
|
59 |
+
super(Decoder, self).__init__()
|
60 |
+
|
61 |
+
self.model = []
|
62 |
+
# AdaIN residual blocks
|
63 |
+
self.model += [DecoderBlock('resblock', n_res, dim, res_norm, activ, pad_type=pad_type)]
|
64 |
+
# upsampling blocks
|
65 |
+
for i in range(n_upsample):
|
66 |
+
self.model += [DecoderBlock('upsample', dim, dim//2,'bn', activ, pad_type)
|
67 |
+
]
|
68 |
+
dim //= 2
|
69 |
+
# use reflection padding in the last conv layer
|
70 |
+
self.output_layer = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)
|
71 |
+
# self.model = nn.Sequential(*self.model)
|
72 |
+
self.model = nn.ModuleList(self.model)
|
73 |
+
|
74 |
+
def forward(self, x):
|
75 |
+
x1 = x.pop()
|
76 |
+
for block in self.model:
|
77 |
+
x2 = x.pop()
|
78 |
+
# print(x1.shape, x2.shape)
|
79 |
+
x1 = block(x1, x2)
|
80 |
+
x1 = self.output_layer(x1)
|
81 |
+
return x1
|
82 |
+
|
83 |
+
|
84 |
+
class Merge(nn.Module):
|
85 |
+
def __init__(self, dim, activation='relu'):
|
86 |
+
super().__init__()
|
87 |
+
self.conv = nn.Conv2d(2*dim, dim, 3, 1, 1)
|
88 |
+
# initialize activation
|
89 |
+
if activation == 'relu':
|
90 |
+
self.activation = nn.ReLU(inplace=True)
|
91 |
+
elif activation == 'lrelu':
|
92 |
+
self.activation = nn.LeakyReLU(0.2, inplace=True)
|
93 |
+
elif activation == 'prelu':
|
94 |
+
self.activation = nn.PReLU()
|
95 |
+
elif activation == 'selu':
|
96 |
+
self.activation = nn.SELU(inplace=True)
|
97 |
+
elif activation == 'tanh':
|
98 |
+
self.activation = nn.Tanh()
|
99 |
+
elif activation == 'none':
|
100 |
+
self.activation = None
|
101 |
+
else:
|
102 |
+
assert 0, "Unsupported activation: {}".format(activation)
|
103 |
+
def forward(self, x1, x2):
|
104 |
+
x = torch.cat([x1, x2], dim=1) # 2xdim
|
105 |
+
x = self.conv(x) # B,dim,H,W
|
106 |
+
x = self.activation(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
class DecoderBlock(nn.Module):
|
110 |
+
def __init__(self, block_type, in_dim, out_dim, norm, activ='relu', pad_type='reflect'):
|
111 |
+
super().__init__()
|
112 |
+
assert block_type in ['resblock', 'upsample']
|
113 |
+
if block_type == 'resblock':
|
114 |
+
self.core_layer = ResBlocks(in_dim, out_dim, norm, activ, pad_type=pad_type)
|
115 |
+
else:
|
116 |
+
assert out_dim == in_dim//2
|
117 |
+
self.core_layer = nn.Sequential(nn.Upsample(scale_factor=2),
|
118 |
+
Conv2dBlock(in_dim, out_dim, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type))
|
119 |
+
self.merge = Merge(out_dim, activ)
|
120 |
+
|
121 |
+
def forward(self, x1, x2):
|
122 |
+
x1 = self.core_layer(x1)
|
123 |
+
return self.merge(x1, x2)
|
ldm/modules/ema.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
|
5 |
+
class LitEma(nn.Module):
|
6 |
+
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
7 |
+
super().__init__()
|
8 |
+
if decay < 0.0 or decay > 1.0:
|
9 |
+
raise ValueError('Decay must be between 0 and 1')
|
10 |
+
|
11 |
+
self.m_name2s_name = {}
|
12 |
+
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
13 |
+
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int) if use_num_upates
|
14 |
+
else torch.tensor(-1, dtype=torch.int))
|
15 |
+
|
16 |
+
for name, p in model.named_parameters():
|
17 |
+
if p.requires_grad:
|
18 |
+
# remove as '.'-character is not allowed in buffers
|
19 |
+
s_name = name.replace('.', '')
|
20 |
+
self.m_name2s_name.update({name: s_name})
|
21 |
+
self.register_buffer(s_name, p.clone().detach().data)
|
22 |
+
|
23 |
+
self.collected_params = []
|
24 |
+
|
25 |
+
def reset_num_updates(self):
|
26 |
+
del self.num_updates
|
27 |
+
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
28 |
+
|
29 |
+
def forward(self, model):
|
30 |
+
decay = self.decay
|
31 |
+
|
32 |
+
if self.num_updates >= 0:
|
33 |
+
self.num_updates += 1
|
34 |
+
decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))
|
35 |
+
|
36 |
+
one_minus_decay = 1.0 - decay
|
37 |
+
|
38 |
+
with torch.no_grad():
|
39 |
+
m_param = dict(model.named_parameters())
|
40 |
+
shadow_params = dict(self.named_buffers())
|
41 |
+
|
42 |
+
for key in m_param:
|
43 |
+
if m_param[key].requires_grad:
|
44 |
+
sname = self.m_name2s_name[key]
|
45 |
+
shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
|
46 |
+
shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
|
47 |
+
else:
|
48 |
+
assert not key in self.m_name2s_name
|
49 |
+
|
50 |
+
def copy_to(self, model):
|
51 |
+
m_param = dict(model.named_parameters())
|
52 |
+
shadow_params = dict(self.named_buffers())
|
53 |
+
for key in m_param:
|
54 |
+
if m_param[key].requires_grad:
|
55 |
+
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
|
56 |
+
else:
|
57 |
+
assert not key in self.m_name2s_name
|
58 |
+
|
59 |
+
def store(self, parameters):
|
60 |
+
"""
|
61 |
+
Save the current parameters for restoring later.
|
62 |
+
Args:
|
63 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
64 |
+
temporarily stored.
|
65 |
+
"""
|
66 |
+
self.collected_params = [param.clone() for param in parameters]
|
67 |
+
|
68 |
+
def restore(self, parameters):
|
69 |
+
"""
|
70 |
+
Restore the parameters stored with the `store` method.
|
71 |
+
Useful to validate the model with EMA parameters without affecting the
|
72 |
+
original optimization process. Store the parameters before the
|
73 |
+
`copy_to` method. After validation (or model saving), use this to
|
74 |
+
restore the former parameters.
|
75 |
+
Args:
|
76 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
77 |
+
updated with the stored parameters.
|
78 |
+
"""
|
79 |
+
for c_param, param in zip(self.collected_params, parameters):
|
80 |
+
param.data.copy_(c_param.data)
|
ldm/util.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import importlib
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from torch import optim
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from inspect import isfunction
|
8 |
+
from PIL import Image, ImageDraw, ImageFont
|
9 |
+
|
10 |
+
|
11 |
+
def log_txt_as_img(wh, xc, size=10):
|
12 |
+
# wh a tuple of (width, height)
|
13 |
+
# xc a list of captions to plot
|
14 |
+
b = len(xc)
|
15 |
+
txts = list()
|
16 |
+
for bi in range(b):
|
17 |
+
txt = Image.new("RGB", wh, color="white")
|
18 |
+
draw = ImageDraw.Draw(txt)
|
19 |
+
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
20 |
+
nc = int(40 * (wh[0] / 256))
|
21 |
+
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
22 |
+
|
23 |
+
try:
|
24 |
+
draw.text((0, 0), lines, fill="black", font=font)
|
25 |
+
except UnicodeEncodeError:
|
26 |
+
print("Cant encode string for logging. Skipping.")
|
27 |
+
|
28 |
+
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
29 |
+
txts.append(txt)
|
30 |
+
txts = np.stack(txts)
|
31 |
+
txts = torch.tensor(txts)
|
32 |
+
return txts
|
33 |
+
|
34 |
+
|
35 |
+
def ismap(x):
|
36 |
+
if not isinstance(x, torch.Tensor):
|
37 |
+
return False
|
38 |
+
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
39 |
+
|
40 |
+
|
41 |
+
def isimage(x):
|
42 |
+
if not isinstance(x,torch.Tensor):
|
43 |
+
return False
|
44 |
+
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
45 |
+
|
46 |
+
|
47 |
+
def exists(x):
|
48 |
+
return x is not None
|
49 |
+
|
50 |
+
|
51 |
+
def default(val, d):
|
52 |
+
if exists(val):
|
53 |
+
return val
|
54 |
+
return d() if isfunction(d) else d
|
55 |
+
|
56 |
+
|
57 |
+
def mean_flat(tensor):
|
58 |
+
"""
|
59 |
+
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
60 |
+
Take the mean over all non-batch dimensions.
|
61 |
+
"""
|
62 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
63 |
+
|
64 |
+
|
65 |
+
def count_params(model, verbose=False):
|
66 |
+
total_params = sum(p.numel() for p in model.parameters())
|
67 |
+
if verbose:
|
68 |
+
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
69 |
+
return total_params
|
70 |
+
|
71 |
+
|
72 |
+
def instantiate_from_config(config):
|
73 |
+
if not "target" in config:
|
74 |
+
if config == '__is_first_stage__':
|
75 |
+
return None
|
76 |
+
elif config == "__is_unconditional__":
|
77 |
+
return None
|
78 |
+
raise KeyError("Expected key `target` to instantiate.")
|
79 |
+
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
80 |
+
|
81 |
+
|
82 |
+
def get_obj_from_str(string, reload=False):
|
83 |
+
module, cls = string.rsplit(".", 1)
|
84 |
+
if reload:
|
85 |
+
module_imp = importlib.import_module(module)
|
86 |
+
importlib.reload(module_imp)
|
87 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
88 |
+
|
89 |
+
|
90 |
+
class AdamWwithEMAandWings(optim.Optimizer):
|
91 |
+
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
92 |
+
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
93 |
+
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
94 |
+
ema_power=1., param_names=()):
|
95 |
+
"""AdamW that saves EMA versions of the parameters."""
|
96 |
+
if not 0.0 <= lr:
|
97 |
+
raise ValueError("Invalid learning rate: {}".format(lr))
|
98 |
+
if not 0.0 <= eps:
|
99 |
+
raise ValueError("Invalid epsilon value: {}".format(eps))
|
100 |
+
if not 0.0 <= betas[0] < 1.0:
|
101 |
+
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
102 |
+
if not 0.0 <= betas[1] < 1.0:
|
103 |
+
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
104 |
+
if not 0.0 <= weight_decay:
|
105 |
+
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
106 |
+
if not 0.0 <= ema_decay <= 1.0:
|
107 |
+
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
108 |
+
defaults = dict(lr=lr, betas=betas, eps=eps,
|
109 |
+
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
110 |
+
ema_power=ema_power, param_names=param_names)
|
111 |
+
super().__init__(params, defaults)
|
112 |
+
|
113 |
+
def __setstate__(self, state):
|
114 |
+
super().__setstate__(state)
|
115 |
+
for group in self.param_groups:
|
116 |
+
group.setdefault('amsgrad', False)
|
117 |
+
|
118 |
+
@torch.no_grad()
|
119 |
+
def step(self, closure=None):
|
120 |
+
"""Performs a single optimization step.
|
121 |
+
Args:
|
122 |
+
closure (callable, optional): A closure that reevaluates the model
|
123 |
+
and returns the loss.
|
124 |
+
"""
|
125 |
+
loss = None
|
126 |
+
if closure is not None:
|
127 |
+
with torch.enable_grad():
|
128 |
+
loss = closure()
|
129 |
+
|
130 |
+
for group in self.param_groups:
|
131 |
+
params_with_grad = []
|
132 |
+
grads = []
|
133 |
+
exp_avgs = []
|
134 |
+
exp_avg_sqs = []
|
135 |
+
ema_params_with_grad = []
|
136 |
+
state_sums = []
|
137 |
+
max_exp_avg_sqs = []
|
138 |
+
state_steps = []
|
139 |
+
amsgrad = group['amsgrad']
|
140 |
+
beta1, beta2 = group['betas']
|
141 |
+
ema_decay = group['ema_decay']
|
142 |
+
ema_power = group['ema_power']
|
143 |
+
|
144 |
+
for p in group['params']:
|
145 |
+
if p.grad is None:
|
146 |
+
continue
|
147 |
+
params_with_grad.append(p)
|
148 |
+
if p.grad.is_sparse:
|
149 |
+
raise RuntimeError('AdamW does not support sparse gradients')
|
150 |
+
grads.append(p.grad)
|
151 |
+
|
152 |
+
state = self.state[p]
|
153 |
+
|
154 |
+
# State initialization
|
155 |
+
if len(state) == 0:
|
156 |
+
state['step'] = 0
|
157 |
+
# Exponential moving average of gradient values
|
158 |
+
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
159 |
+
# Exponential moving average of squared gradient values
|
160 |
+
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
161 |
+
if amsgrad:
|
162 |
+
# Maintains max of all exp. moving avg. of sq. grad. values
|
163 |
+
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
164 |
+
# Exponential moving average of parameter values
|
165 |
+
state['param_exp_avg'] = p.detach().float().clone()
|
166 |
+
|
167 |
+
exp_avgs.append(state['exp_avg'])
|
168 |
+
exp_avg_sqs.append(state['exp_avg_sq'])
|
169 |
+
ema_params_with_grad.append(state['param_exp_avg'])
|
170 |
+
|
171 |
+
if amsgrad:
|
172 |
+
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
173 |
+
|
174 |
+
# update the steps for each param group update
|
175 |
+
state['step'] += 1
|
176 |
+
# record the step after step update
|
177 |
+
state_steps.append(state['step'])
|
178 |
+
|
179 |
+
optim._functional.adamw(params_with_grad,
|
180 |
+
grads,
|
181 |
+
exp_avgs,
|
182 |
+
exp_avg_sqs,
|
183 |
+
max_exp_avg_sqs,
|
184 |
+
state_steps,
|
185 |
+
amsgrad=amsgrad,
|
186 |
+
beta1=beta1,
|
187 |
+
beta2=beta2,
|
188 |
+
lr=group['lr'],
|
189 |
+
weight_decay=group['weight_decay'],
|
190 |
+
eps=group['eps'],
|
191 |
+
maximize=False)
|
192 |
+
|
193 |
+
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
194 |
+
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
195 |
+
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
196 |
+
|
197 |
+
return loss
|
pages/Extract_Secret.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
streamlit app demo
|
5 |
+
how to run:
|
6 |
+
streamlit run app.py --server.port 8501
|
7 |
+
|
8 |
+
@author: Tu Bui @surrey.ac.uk
|
9 |
+
"""
|
10 |
+
import os, sys, torch
|
11 |
+
import inspect
|
12 |
+
cdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
|
13 |
+
sys.path.insert(1, os.path.join(cdir, '../'))
|
14 |
+
import argparse
|
15 |
+
from pathlib import Path
|
16 |
+
import numpy as np
|
17 |
+
import pickle
|
18 |
+
import pytorch_lightning as pl
|
19 |
+
from torchvision import transforms
|
20 |
+
import argparse
|
21 |
+
from ldm.util import instantiate_from_config
|
22 |
+
from omegaconf import OmegaConf
|
23 |
+
from PIL import Image
|
24 |
+
from tools.augment_imagenetc import RandomImagenetC
|
25 |
+
from cldm.transformations2 import TransformNet
|
26 |
+
from io import BytesIO
|
27 |
+
from tools.helpers import welcome_message
|
28 |
+
from tools.ecc import BCH, RSC
|
29 |
+
import streamlit as st
|
30 |
+
from Embed_Secret import parse_st_args, load_ecc, load_model, decode_secret, to_bytes, model_names
|
31 |
+
|
32 |
+
|
33 |
+
def app(args):
|
34 |
+
st.title('Watermarking Demo')
|
35 |
+
# setup model
|
36 |
+
model_name = st.selectbox("Choose the model", model_names)
|
37 |
+
model, tform_emb, tform_det, secret_len = load_model(model_name, args)
|
38 |
+
display_width = 300
|
39 |
+
ecc = load_ecc('BCH', secret_len)
|
40 |
+
noise = TransformNet(p=1.0, crop_mode='resized_crop')
|
41 |
+
noise_names = noise.optional_names
|
42 |
+
|
43 |
+
# setup st
|
44 |
+
st.subheader("Input")
|
45 |
+
image_file = None
|
46 |
+
image_file = st.file_uploader("Upload stego image", type=["png","jpg","jpeg"])
|
47 |
+
if image_file is not None:
|
48 |
+
im = Image.open(image_file).convert('RGB')
|
49 |
+
ext = image_file.name.split('.')[-1]
|
50 |
+
st.image(im, width=display_width)
|
51 |
+
|
52 |
+
|
53 |
+
# add crop
|
54 |
+
st.subheader("Corruptions")
|
55 |
+
crop_button = st.button('Regenerate Crop/Flip/Resize', key='crop')
|
56 |
+
if image_file is not None:
|
57 |
+
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
|
58 |
+
if crop_button:
|
59 |
+
im_crop = noise.apply_transform_on_pil_image(im, 'Fixed Augment')
|
60 |
+
# st.image(im_crop, width=display_width)
|
61 |
+
|
62 |
+
# add noise source 1
|
63 |
+
corrupt_method1 = st.selectbox("Choose noise source #1", ['None'] + noise_names, key='noise1')
|
64 |
+
if image_file is not None:
|
65 |
+
if corrupt_method1=='None':
|
66 |
+
im_noise1 = im_crop
|
67 |
+
else:
|
68 |
+
im_noise1 = noise.apply_transform_on_pil_image(im_crop, corrupt_method1)
|
69 |
+
# st.image(im_noise1, width=display_width)
|
70 |
+
|
71 |
+
# add noise source 2
|
72 |
+
corrupt_method2 = st.selectbox("Choose noise source #2", ['None'] + noise_names, key='noise2')
|
73 |
+
if image_file is not None:
|
74 |
+
if corrupt_method2=='None':
|
75 |
+
im_noise2 = im_noise1
|
76 |
+
else:
|
77 |
+
im_noise2 = noise.apply_transform_on_pil_image(im_noise1, corrupt_method2)
|
78 |
+
|
79 |
+
st.subheader("Output")
|
80 |
+
if image_file is not None:
|
81 |
+
st.image(im_noise2, width=display_width)
|
82 |
+
mime='image/jpeg' if ext=='jpg' else f'image/{ext}'
|
83 |
+
im_noise2_bytes = to_bytes(np.uint8(im_noise2), mime)
|
84 |
+
st.download_button(label='Download image', data=im_noise2_bytes, file_name=f'corrupted.{ext}', mime=mime)
|
85 |
+
|
86 |
+
# prediction
|
87 |
+
st.subheader('Extract Secret From Output')
|
88 |
+
status = st.empty()
|
89 |
+
if image_file is not None:
|
90 |
+
secret_pred = decode_secret(model_name, model, im_noise2, tform_det)
|
91 |
+
secret_decoded = ecc.decode_text(secret_pred)[0]
|
92 |
+
status.markdown(f'Predicted secret: **{secret_decoded}**', unsafe_allow_html=True)
|
93 |
+
|
94 |
+
# bit acc
|
95 |
+
st.subheader('Accuracy')
|
96 |
+
secret_text = st.text_input('Input groundtruth secret')
|
97 |
+
bit_acc_status = st.empty()
|
98 |
+
if image_file is not None and secret_text:
|
99 |
+
secret = ecc.encode_text([secret_text]) # (1, 100)
|
100 |
+
bit_acc = (secret_pred == secret).mean()
|
101 |
+
# bit_acc_status.markdown('**Bit Accuracy**: {:.2f}%'.format(bit_acc*100), unsafe_allow_html=True)
|
102 |
+
word_acc = int(secret_decoded == secret_text)
|
103 |
+
bit_acc_status.markdown(f'Bit Accuracy: **{bit_acc*100:.2f}%**<br />Word Accuracy: **{word_acc}**', unsafe_allow_html=True)
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
args = parse_st_args()
|
107 |
+
app(args)
|
108 |
+
|
tools/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .helpers import *
|
2 |
+
from .hparams import HParams
|
3 |
+
from .slack_bot import Notifier
|
tools/augment_imagenetc.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
wrapper for imagenet-c transformations
|
5 |
+
@author: Tu Bui @surrey.ac.uk
|
6 |
+
"""
|
7 |
+
from __future__ import absolute_import
|
8 |
+
from __future__ import division
|
9 |
+
from __future__ import print_function
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import random
|
13 |
+
import numpy as np
|
14 |
+
from PIL import Image
|
15 |
+
from imagenet_c import corrupt, corruption_dict
|
16 |
+
|
17 |
+
|
18 |
+
class IdentityAugment(object):
|
19 |
+
def __call__(self, x):
|
20 |
+
return x
|
21 |
+
|
22 |
+
def __repr__(self):
|
23 |
+
s = f'()'
|
24 |
+
return self.__class__.__name__ + s
|
25 |
+
|
26 |
+
class RandomImagenetC(object):
|
27 |
+
# transform id 5 (motion blur) and 7 (snow) requires WandImage which is not fork-safe, while id 4 (glass blur) and 6 (zoom blur) are super slow thus we move it to validation (unseen), 12 (elastic transform) is non realistic
|
28 |
+
methods = {'train': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18]),#np.arange(15),
|
29 |
+
'val': np.array([4, 5, 6, 7, 12]),
|
30 |
+
'test': np.array([0,1,2,3,8,9,10,11,13,14,15, 16, 17, 18])
|
31 |
+
}
|
32 |
+
method_names = list(corruption_dict.keys())
|
33 |
+
def __init__(self, min_severity=1, max_severity=5, phase='all', p=1.0,n=19):
|
34 |
+
assert phase in ['train', 'val', 'test', 'all'], ValueError(f'{phase} not recognised. Must be one of [train, val, all]')
|
35 |
+
if phase == 'all':
|
36 |
+
self.corrupt_ids = np.concatenate(list(self.methods.values()))
|
37 |
+
else:
|
38 |
+
self.corrupt_ids = self.methods[phase]
|
39 |
+
self.corrupt_ids = self.corrupt_ids[:n] # first n tforms
|
40 |
+
self.phase = phase
|
41 |
+
self.severity = np.arange(min_severity, max_severity+1)
|
42 |
+
self.p = p # probability to apply a transformation
|
43 |
+
|
44 |
+
def __call__(self, x, corrupt_id=None, corrupt_strength=None):
|
45 |
+
# input: x PIL image
|
46 |
+
if corrupt_id is None:
|
47 |
+
if len(self.corrupt_ids)==0: # do nothing
|
48 |
+
return x
|
49 |
+
corrupt_id = np.random.choice(self.corrupt_ids)
|
50 |
+
else:
|
51 |
+
assert corrupt_id in range(19)
|
52 |
+
|
53 |
+
severity = np.random.choice(self.severity) if corrupt_strength is None else corrupt_strength
|
54 |
+
assert severity in self.severity, f'Error! Corrupt strength {severity} isnt supported.'
|
55 |
+
|
56 |
+
if np.random.rand() < self.p:
|
57 |
+
org_size = x.size
|
58 |
+
x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
|
59 |
+
x = corrupt(x, severity, corruption_number=corrupt_id)
|
60 |
+
x = Image.fromarray(x[:,:,::-1])
|
61 |
+
if x.size != org_size:
|
62 |
+
x = x.resize(org_size, Image.BILINEAR)
|
63 |
+
return x
|
64 |
+
|
65 |
+
def transform_with_fixed_severity(self, x, severity, corrupt_id=None):
|
66 |
+
if corrupt_id is None:
|
67 |
+
corrupt_id = np.random.choice(self.corrupt_ids)
|
68 |
+
else:
|
69 |
+
assert corrupt_id in self.corrupt_ids
|
70 |
+
assert severity > 0 and severity < 6
|
71 |
+
org_size = x.size
|
72 |
+
x = np.asarray(x.convert('RGB').resize((224, 224), Image.BILINEAR))[:,:,::-1]
|
73 |
+
x = corrupt(x, severity, corruption_number=corrupt_id)
|
74 |
+
x = Image.fromarray(x[:,:,::-1])
|
75 |
+
if x.size != org_size:
|
76 |
+
x = x.resize(org_size, Image.BILINEAR)
|
77 |
+
return x
|
78 |
+
|
79 |
+
def __repr__(self):
|
80 |
+
s = f'(severity={self.severity}, phase={self.phase}, p={self.p},ids={self.corrupt_ids})'
|
81 |
+
return self.__class__.__name__ + s
|
82 |
+
|
83 |
+
|
84 |
+
class NoiseResidual(object):
|
85 |
+
def __init__(self, k=16):
|
86 |
+
self.k = k
|
87 |
+
def __call__(self, x):
|
88 |
+
h, w = x.height, x.width
|
89 |
+
x1 = x.resize((w//self.k,h//self.k), Image.BILINEAR).resize((w, h), Image.BILINEAR)
|
90 |
+
x1 = np.abs(np.array(x).astype(np.float32) - np.array(x1).astype(np.float32))
|
91 |
+
x1 = (x1 - x1.min())/(x1.max() - x1.min() + np.finfo(np.float32).eps)
|
92 |
+
x1 = Image.fromarray((x1*255).astype(np.uint8))
|
93 |
+
return x1
|
94 |
+
def __repr__(self):
|
95 |
+
s = f'(k={self.k}'
|
96 |
+
return self.__class__.__name__ + s
|
97 |
+
|
98 |
+
|
99 |
+
def get_transforms(img_mean=[0.5, 0.5, 0.5], img_std=[0.5, 0.5, 0.5], rsize=256, csize=224, pertubation=True, dct=False, residual=False, max_c=19):
|
100 |
+
from torchvision import transforms
|
101 |
+
prep = transforms.Compose([
|
102 |
+
transforms.Resize(rsize),
|
103 |
+
transforms.RandomHorizontalFlip(),
|
104 |
+
transforms.RandomCrop(csize)])
|
105 |
+
if pertubation:
|
106 |
+
pertubation_train = RandomImagenetC(max_severity=5, phase='train', p=0.95,n=max_c)
|
107 |
+
pertubation_val = RandomImagenetC(max_severity=5, phase='train', p=1.0,n=max_c)
|
108 |
+
pertubation_test = RandomImagenetC(max_severity=5, phase='val', p=1.0,n=max_c)
|
109 |
+
else:
|
110 |
+
pertubation_train = pertubation_val = pertubation_test = IdentityAugment()
|
111 |
+
if dct:
|
112 |
+
from .image_tools import DCT
|
113 |
+
norm = [
|
114 |
+
DCT(),
|
115 |
+
transforms.ToTensor(),
|
116 |
+
transforms.Normalize(mean=img_mean, std=img_std)]
|
117 |
+
else:
|
118 |
+
norm = [
|
119 |
+
transforms.ToTensor(),
|
120 |
+
transforms.Normalize(mean=img_mean, std=img_std)]
|
121 |
+
if residual:
|
122 |
+
norm.insert(0, NoiseResidual())
|
123 |
+
|
124 |
+
preprocess = {
|
125 |
+
'train': [prep, pertubation_train, transforms.Compose(norm)],
|
126 |
+
|
127 |
+
'val': [prep, pertubation_val, transforms.Compose(norm)],
|
128 |
+
|
129 |
+
'test_unseen': [prep, pertubation_test, transforms.Compose(norm)],
|
130 |
+
|
131 |
+
'clean': transforms.Compose([transforms.Resize(csize)] + norm)
|
132 |
+
}
|
133 |
+
return preprocess
|
134 |
+
|
135 |
+
|
136 |
+
# ## example
|
137 |
+
# from PIL import Image
|
138 |
+
# import numpy as np
|
139 |
+
# import time
|
140 |
+
# from imagenet_c import corrupt, corruption_dict
|
141 |
+
# im = Image.open('/vol/research/tubui1/projects/gan_prov/gan_models/stargan2/test.jpg').convert('RGB').resize((224,224), Image.BILINEAR)
|
142 |
+
# im.save('original.jpg')
|
143 |
+
# im = np.array(im)[:,:,::-1] # BRG
|
144 |
+
# t = np.zeros(19)
|
145 |
+
# for i, key in enumerate(corruption_dict.keys()):
|
146 |
+
# begin = time.time()
|
147 |
+
# for j in range(10):
|
148 |
+
# out = corrupt(im, 5, corruption_number=i)
|
149 |
+
# end = time.time()
|
150 |
+
# t[i] = end-begin
|
151 |
+
# # Image.fromarray(out[:,:,::-1]).save(f'imc_{key}.jpg')
|
152 |
+
# print(f'{i} - {key}: {end-begin}')
|
153 |
+
|
154 |
+
# for i,k in enumerate(corruption_dict.keys()):
|
155 |
+
# print(i, k, t[i])
|
tools/base_lmdb.py
ADDED
@@ -0,0 +1,588 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional, Union
|
2 |
+
from pathlib import Path
|
3 |
+
import os
|
4 |
+
import io
|
5 |
+
import lmdb
|
6 |
+
import pickle
|
7 |
+
import gzip
|
8 |
+
import bz2
|
9 |
+
import lzma
|
10 |
+
import shutil
|
11 |
+
from tqdm import tqdm
|
12 |
+
import pandas as pd
|
13 |
+
import numpy as np
|
14 |
+
from numpy import ndarray
|
15 |
+
import time
|
16 |
+
import torch
|
17 |
+
from torch import Tensor
|
18 |
+
from distutils.dir_util import copy_tree
|
19 |
+
from PIL import Image
|
20 |
+
from PIL import ImageFile
|
21 |
+
|
22 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
23 |
+
|
24 |
+
|
25 |
+
def _default_encode(data: Any, protocol: int) -> bytes:
|
26 |
+
return pickle.dumps(data, protocol=protocol)
|
27 |
+
|
28 |
+
|
29 |
+
def _ascii_encode(data: str) -> bytes:
|
30 |
+
return data.encode("ascii")
|
31 |
+
|
32 |
+
|
33 |
+
def _default_decode(data: bytes) -> Any:
|
34 |
+
return pickle.loads(data)
|
35 |
+
|
36 |
+
|
37 |
+
def _default_decompress(data: bytes) -> bytes:
|
38 |
+
return data
|
39 |
+
|
40 |
+
|
41 |
+
def _decompress(compression: Optional[str]):
|
42 |
+
if compression is None:
|
43 |
+
_decompress = _default_decompress
|
44 |
+
elif compression == "gzip":
|
45 |
+
_decompress = gzip.decompress
|
46 |
+
elif compression == "bz2":
|
47 |
+
_decompress = bz2.decompress
|
48 |
+
elif compression == "lzma":
|
49 |
+
_decompress = lzma.decompress
|
50 |
+
else:
|
51 |
+
raise ValueError(f"Unknown compression algorithm: {compression}")
|
52 |
+
|
53 |
+
return _decompress
|
54 |
+
|
55 |
+
|
56 |
+
class BaseLMDB(object):
|
57 |
+
_database = None
|
58 |
+
_protocol = None
|
59 |
+
_length = None
|
60 |
+
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
path: Union[str, Path],
|
64 |
+
readahead: bool = False,
|
65 |
+
pre_open: bool = False,
|
66 |
+
compression: Optional[str] = None
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Base class for LMDB-backed databases.
|
70 |
+
|
71 |
+
:param path: Path to the database.
|
72 |
+
:param readahead: Enables the filesystem readahead mechanism.
|
73 |
+
:param pre_open: If set to True, the first iterations will be faster, but it will raise error when doing multi-gpu training. If set to False, the database will open when you will retrieve the first item.
|
74 |
+
"""
|
75 |
+
if not isinstance(path, str):
|
76 |
+
path = str(path)
|
77 |
+
|
78 |
+
self.path = path
|
79 |
+
self.readahead = readahead
|
80 |
+
self.pre_open = pre_open
|
81 |
+
self._decompress = _decompress(compression)
|
82 |
+
self._has_fetched_an_item = False
|
83 |
+
|
84 |
+
@property
|
85 |
+
def database(self):
|
86 |
+
if self._database is None:
|
87 |
+
self._database = lmdb.open(
|
88 |
+
path=self.path,
|
89 |
+
readonly=True,
|
90 |
+
readahead=self.readahead,
|
91 |
+
max_spare_txns=256,
|
92 |
+
lock=False,
|
93 |
+
)
|
94 |
+
return self._database
|
95 |
+
|
96 |
+
@database.deleter
|
97 |
+
def database(self):
|
98 |
+
if self._database is not None:
|
99 |
+
self._database.close()
|
100 |
+
self._database = None
|
101 |
+
|
102 |
+
@property
|
103 |
+
def protocol(self):
|
104 |
+
"""
|
105 |
+
Read the pickle protocol contained in the database.
|
106 |
+
|
107 |
+
:return: The set of available keys.
|
108 |
+
"""
|
109 |
+
if self._protocol is None:
|
110 |
+
self._protocol = self._get(
|
111 |
+
item="protocol",
|
112 |
+
encode_key=_ascii_encode,
|
113 |
+
decompress_value=_default_decompress,
|
114 |
+
decode_value=_default_decode,
|
115 |
+
)
|
116 |
+
return self._protocol
|
117 |
+
|
118 |
+
@property
|
119 |
+
def keys(self):
|
120 |
+
"""
|
121 |
+
Read the keys contained in the database.
|
122 |
+
|
123 |
+
:return: The set of available keys.
|
124 |
+
"""
|
125 |
+
protocol = self.protocol
|
126 |
+
keys = self._get(
|
127 |
+
item="keys",
|
128 |
+
encode_key=lambda key: _default_encode(key, protocol=protocol),
|
129 |
+
decompress_value=_default_decompress,
|
130 |
+
decode_value=_default_decode,
|
131 |
+
)
|
132 |
+
return keys
|
133 |
+
|
134 |
+
def __len__(self):
|
135 |
+
"""
|
136 |
+
Returns the number of keys available in the database.
|
137 |
+
|
138 |
+
:return: The number of keys.
|
139 |
+
"""
|
140 |
+
if self._length is None:
|
141 |
+
self._length = len(self.keys)
|
142 |
+
return self._length
|
143 |
+
|
144 |
+
def __getitem__(self, item):
|
145 |
+
"""
|
146 |
+
Retrieves an item or a list of items from the database.
|
147 |
+
|
148 |
+
:param item: A key or a list of keys.
|
149 |
+
:return: A value or a list of values.
|
150 |
+
"""
|
151 |
+
self._has_fetched_an_item = True
|
152 |
+
if not isinstance(item, list):
|
153 |
+
item = self._get(
|
154 |
+
item=item,
|
155 |
+
encode_key=self._encode_key,
|
156 |
+
decompress_value=self._decompress_value,
|
157 |
+
decode_value=self._decode_value,
|
158 |
+
)
|
159 |
+
else:
|
160 |
+
item = self._gets(
|
161 |
+
items=item,
|
162 |
+
encode_keys=self._encode_keys,
|
163 |
+
decompress_values=self._decompress_values,
|
164 |
+
decode_values=self._decode_values,
|
165 |
+
)
|
166 |
+
return item
|
167 |
+
|
168 |
+
def _get(self, item, encode_key, decompress_value, decode_value):
|
169 |
+
"""
|
170 |
+
Instantiates a transaction and its associated cursor to fetch an item.
|
171 |
+
|
172 |
+
:param item: A key.
|
173 |
+
:param encode_key:
|
174 |
+
:param decode_value:
|
175 |
+
:return:
|
176 |
+
"""
|
177 |
+
with self.database.begin() as txn:
|
178 |
+
with txn.cursor() as cursor:
|
179 |
+
item = self._fetch(
|
180 |
+
cursor=cursor,
|
181 |
+
key=item,
|
182 |
+
encode_key=encode_key,
|
183 |
+
decompress_value=decompress_value,
|
184 |
+
decode_value=decode_value,
|
185 |
+
)
|
186 |
+
self._keep_database()
|
187 |
+
return item
|
188 |
+
|
189 |
+
def _gets(self, items, encode_keys, decompress_values, decode_values):
|
190 |
+
"""
|
191 |
+
Instantiates a transaction and its associated cursor to fetch a list of items.
|
192 |
+
|
193 |
+
:param items: A list of keys.
|
194 |
+
:param encode_keys:
|
195 |
+
:param decode_values:
|
196 |
+
:return:
|
197 |
+
"""
|
198 |
+
with self.database.begin() as txn:
|
199 |
+
with txn.cursor() as cursor:
|
200 |
+
items = self._fetchs(
|
201 |
+
cursor=cursor,
|
202 |
+
keys=items,
|
203 |
+
encode_keys=encode_keys,
|
204 |
+
decompress_values=decompress_values,
|
205 |
+
decode_values=decode_values,
|
206 |
+
)
|
207 |
+
self._keep_database()
|
208 |
+
return items
|
209 |
+
|
210 |
+
def _fetch(self, cursor, key, encode_key, decompress_value, decode_value):
|
211 |
+
"""
|
212 |
+
Retrieve a value given a key.
|
213 |
+
|
214 |
+
:param cursor:
|
215 |
+
:param key: A key.
|
216 |
+
:param encode_key:
|
217 |
+
:param decode_value:
|
218 |
+
:return: A value.
|
219 |
+
"""
|
220 |
+
key = encode_key(key)
|
221 |
+
value = cursor.get(key)
|
222 |
+
value = decompress_value(value)
|
223 |
+
value = decode_value(value)
|
224 |
+
return value
|
225 |
+
|
226 |
+
def _fetchs(self, cursor, keys, encode_keys, decompress_values, decode_values):
|
227 |
+
"""
|
228 |
+
Retrieve a list of values given a list of keys.
|
229 |
+
|
230 |
+
:param cursor:
|
231 |
+
:param keys: A list of keys.
|
232 |
+
:param encode_keys:
|
233 |
+
:param decode_values:
|
234 |
+
:return: A list of values.
|
235 |
+
"""
|
236 |
+
keys = encode_keys(keys)
|
237 |
+
_, values = list(zip(*cursor.getmulti(keys)))
|
238 |
+
values = decompress_values(values)
|
239 |
+
values = decode_values(values)
|
240 |
+
return values
|
241 |
+
|
242 |
+
def _encode_key(self, key: Any) -> bytes:
|
243 |
+
"""
|
244 |
+
Converts a key into a byte key.
|
245 |
+
|
246 |
+
:param key: A key.
|
247 |
+
:return: A byte key.
|
248 |
+
"""
|
249 |
+
return pickle.dumps(key, protocol=self.protocol)
|
250 |
+
|
251 |
+
def _encode_keys(self, keys: list) -> list:
|
252 |
+
"""
|
253 |
+
Converts keys into byte keys.
|
254 |
+
|
255 |
+
:param keys: A list of keys.
|
256 |
+
:return: A list of byte keys.
|
257 |
+
"""
|
258 |
+
return [self._encode_key(key=key) for key in keys]
|
259 |
+
|
260 |
+
def _decompress_value(self, value: bytes) -> bytes:
|
261 |
+
return self._decompress(value)
|
262 |
+
|
263 |
+
def _decompress_values(self, values: list) -> list:
|
264 |
+
return [self._decompress_value(value=value) for value in values]
|
265 |
+
|
266 |
+
def _decode_value(self, value: bytes) -> Any:
|
267 |
+
"""
|
268 |
+
Converts a byte value back into a value.
|
269 |
+
|
270 |
+
:param value: A byte value.
|
271 |
+
:return: A value
|
272 |
+
"""
|
273 |
+
return pickle.loads(value)
|
274 |
+
|
275 |
+
def _decode_values(self, values: list) -> list:
|
276 |
+
"""
|
277 |
+
Converts bytes values back into values.
|
278 |
+
|
279 |
+
:param values: A list of byte values.
|
280 |
+
:return: A list of values.
|
281 |
+
"""
|
282 |
+
return [self._decode_value(value=value) for value in values]
|
283 |
+
|
284 |
+
def _keep_database(self):
|
285 |
+
"""
|
286 |
+
Checks if the database must be deleted.
|
287 |
+
|
288 |
+
:return:
|
289 |
+
"""
|
290 |
+
if not self.pre_open and not self._has_fetched_an_item:
|
291 |
+
del self.database
|
292 |
+
|
293 |
+
def __iter__(self):
|
294 |
+
"""
|
295 |
+
Provides an iterator over the keys when iterating over the database.
|
296 |
+
|
297 |
+
:return: An iterator on the keys.
|
298 |
+
"""
|
299 |
+
return iter(self.keys)
|
300 |
+
|
301 |
+
def __del__(self):
|
302 |
+
"""
|
303 |
+
Closes the database properly.
|
304 |
+
"""
|
305 |
+
del self.database
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def write(data_lst, indir, outdir):
|
309 |
+
raise NotImplementedError
|
310 |
+
|
311 |
+
|
312 |
+
class PILlmdb(BaseLMDB):
|
313 |
+
def __init__(
|
314 |
+
self,
|
315 |
+
lmdb_dir: Union[str, Path],
|
316 |
+
image_list: Union[str, Path, pd.DataFrame]=None,
|
317 |
+
index_key='id',
|
318 |
+
**kwargs
|
319 |
+
):
|
320 |
+
super().__init__(path=lmdb_dir, **kwargs)
|
321 |
+
if image_list is None:
|
322 |
+
self.ids = list(range(len(self.keys)))
|
323 |
+
self.labels = list(range(len(self.ids)))
|
324 |
+
else:
|
325 |
+
df = pd.read_csv(str(image_list))
|
326 |
+
assert index_key in df, f'[PILlmdb] Error! {image_list} must have id keys.'
|
327 |
+
self.ids = df[index_key].tolist()
|
328 |
+
assert max(self.ids) < len(self.keys)
|
329 |
+
if 'label' in df:
|
330 |
+
self.labels = df['label'].tolist()
|
331 |
+
else: # all numeric keys other than 'id' are labels
|
332 |
+
keys = [key for key in df if (key!=index_key and type(df[key][0]) in [int, np.int64])]
|
333 |
+
# df = df.drop('id', axis=1)
|
334 |
+
self.labels = df[keys].to_numpy()
|
335 |
+
self._length = len(self.ids)
|
336 |
+
|
337 |
+
def __len__(self):
|
338 |
+
return self._length
|
339 |
+
|
340 |
+
def __iter__(self):
|
341 |
+
return iter([self.keys[i] for i in self.ids])
|
342 |
+
|
343 |
+
def __getitem__(self, index):
|
344 |
+
key = self.keys[self.ids[index]]
|
345 |
+
return super().__getitem__(key)
|
346 |
+
|
347 |
+
def set_ids(self, ids):
|
348 |
+
self.ids = [self.ids[i] for i in ids]
|
349 |
+
self.labels = [self.labels[i] for i in ids]
|
350 |
+
self._length = len(self.ids)
|
351 |
+
|
352 |
+
def _decode_value(self, value: bytes):
|
353 |
+
"""
|
354 |
+
Converts a byte image back into a PIL Image.
|
355 |
+
|
356 |
+
:param value: A byte image.
|
357 |
+
:return: A PIL Image image.
|
358 |
+
"""
|
359 |
+
return Image.open(io.BytesIO(value))
|
360 |
+
|
361 |
+
@staticmethod
|
362 |
+
def write(indir, outdir, data_lst=None, transform=None):
|
363 |
+
"""
|
364 |
+
create lmdb given data directory and list of image paths; or an iterator
|
365 |
+
:param data_lst None or csv file containing 'path' key to store relative paths to the images
|
366 |
+
:param indir root directory of the images
|
367 |
+
:param outdir output lmdb, data.mdb and lock.mdb will be written here
|
368 |
+
"""
|
369 |
+
|
370 |
+
outdir = Path(outdir)
|
371 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
372 |
+
tmp_dir = Path("/tmp") / f"TEMP_{time.time()}"
|
373 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
374 |
+
dtype = {'str': False, 'pil': False}
|
375 |
+
if isinstance(indir, str) or isinstance(indir, Path):
|
376 |
+
indir = Path(indir)
|
377 |
+
if data_lst is None: # grab all images in this dir
|
378 |
+
lst = list(indir.glob('**/*.jpg')) + list(indir.glob('**/*.png'))
|
379 |
+
else:
|
380 |
+
lst = pd.read_csv(data_lst)['path'].tolist()
|
381 |
+
lst = [indir/p for p in lst]
|
382 |
+
assert len(lst) > 0, f'Couldnt find any image in {indir} (Support only .jpg and .png) or list (must have path field).'
|
383 |
+
n = len(lst)
|
384 |
+
dtype['str'] = True
|
385 |
+
else: # iterator
|
386 |
+
n = len(indir)
|
387 |
+
lst = iter(indir)
|
388 |
+
dtype['pil'] = True
|
389 |
+
|
390 |
+
with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env:
|
391 |
+
# Add the protocol to the database.
|
392 |
+
with env.begin(write=True) as txn:
|
393 |
+
key = "protocol".encode("ascii")
|
394 |
+
value = pickle.dumps(pickle.DEFAULT_PROTOCOL)
|
395 |
+
txn.put(key=key, value=value, dupdata=False)
|
396 |
+
# Add the keys to the database.
|
397 |
+
with env.begin(write=True) as txn:
|
398 |
+
key = pickle.dumps("keys")
|
399 |
+
value = pickle.dumps(list(range(n)))
|
400 |
+
txn.put(key=key, value=value, dupdata=False)
|
401 |
+
# Add the images to the database.
|
402 |
+
for key, value in tqdm(enumerate(lst), total=n, miniters=n//100, mininterval=300):
|
403 |
+
with env.begin(write=True) as txn:
|
404 |
+
key = pickle.dumps(key)
|
405 |
+
if dtype['str']:
|
406 |
+
with value.open("rb") as file:
|
407 |
+
byteimg = file.read()
|
408 |
+
else: # PIL
|
409 |
+
data = io.BytesIO()
|
410 |
+
value.save(data, 'png')
|
411 |
+
byteimg = data.getvalue()
|
412 |
+
|
413 |
+
if transform is not None:
|
414 |
+
im = Image.open(io.BytesIO(byteimg))
|
415 |
+
im = transform(im)
|
416 |
+
data = io.BytesIO()
|
417 |
+
im.save(data, 'png')
|
418 |
+
byteimg = data.getvalue()
|
419 |
+
txn.put(key=key, value=byteimg, dupdata=False)
|
420 |
+
|
421 |
+
# Move the database to its destination.
|
422 |
+
copy_tree(str(tmp_dir), str(outdir))
|
423 |
+
shutil.rmtree(str(tmp_dir))
|
424 |
+
|
425 |
+
|
426 |
+
|
427 |
+
class MaskDatabase(PILlmdb):
|
428 |
+
def _decode_value(self, value: bytes):
|
429 |
+
"""
|
430 |
+
Converts a byte image back into a PIL Image.
|
431 |
+
|
432 |
+
:param value: A byte image.
|
433 |
+
:return: A PIL Image image.
|
434 |
+
"""
|
435 |
+
return Image.open(io.BytesIO(value)).convert("1")
|
436 |
+
|
437 |
+
|
438 |
+
class LabelDatabase(BaseLMDB):
|
439 |
+
pass
|
440 |
+
|
441 |
+
|
442 |
+
class ArrayDatabase(BaseLMDB):
|
443 |
+
_dtype = None
|
444 |
+
_shape = None
|
445 |
+
|
446 |
+
def __init__(
|
447 |
+
self,
|
448 |
+
lmdb_dir: Union[str, Path],
|
449 |
+
image_list: Union[str, Path, pd.DataFrame]=None,
|
450 |
+
**kwargs
|
451 |
+
):
|
452 |
+
super().__init__(path=lmdb_dir, **kwargs)
|
453 |
+
if image_list is None:
|
454 |
+
self.ids = list(range(len(self.keys)))
|
455 |
+
self.labels = list(range(len(self.ids)))
|
456 |
+
else:
|
457 |
+
df = pd.read_csv(str(image_list))
|
458 |
+
assert 'id' in df, f'[ArrayDatabase] Error! {image_list} must have id keys.'
|
459 |
+
self.ids = df['id'].tolist()
|
460 |
+
assert max(self.ids) < len(self.keys)
|
461 |
+
if 'label' in df:
|
462 |
+
self.labels = df['label'].tolist()
|
463 |
+
else: # all numeric keys other than 'id' are labels
|
464 |
+
keys = [key for key in df if (key!='id' and type(df[key][0]) in [int, np.int64])]
|
465 |
+
# df = df.drop('id', axis=1)
|
466 |
+
self.labels = df[keys].to_numpy()
|
467 |
+
self._length = len(self.ids)
|
468 |
+
|
469 |
+
def set_ids(self, ids):
|
470 |
+
self.ids = [self.ids[i] for i in ids]
|
471 |
+
self.labels = [self.labels[i] for i in ids]
|
472 |
+
self._length = len(self.ids)
|
473 |
+
|
474 |
+
def __len__(self):
|
475 |
+
return self._length
|
476 |
+
|
477 |
+
def __iter__(self):
|
478 |
+
return iter([self.keys[i] for i in self.ids])
|
479 |
+
|
480 |
+
def __getitem__(self, index):
|
481 |
+
key = self.keys[self.ids[index]]
|
482 |
+
return super().__getitem__(key)
|
483 |
+
|
484 |
+
@property
|
485 |
+
def dtype(self):
|
486 |
+
if self._dtype is None:
|
487 |
+
protocol = self.protocol
|
488 |
+
self._dtype = self._get(
|
489 |
+
item="dtype",
|
490 |
+
encode_key=lambda key: _default_encode(key, protocol=protocol),
|
491 |
+
decompress_value=_default_decompress,
|
492 |
+
decode_value=_default_decode,
|
493 |
+
)
|
494 |
+
return self._dtype
|
495 |
+
|
496 |
+
@property
|
497 |
+
def shape(self):
|
498 |
+
if self._shape is None:
|
499 |
+
protocol = self.protocol
|
500 |
+
self._shape = self._get(
|
501 |
+
item="shape",
|
502 |
+
encode_key=lambda key: _default_encode(key, protocol=protocol),
|
503 |
+
decompress_value=_default_decompress,
|
504 |
+
decode_value=_default_decode,
|
505 |
+
)
|
506 |
+
return self._shape
|
507 |
+
|
508 |
+
def _decode_value(self, value: bytes) -> ndarray:
|
509 |
+
value = super()._decode_value(value)
|
510 |
+
return np.frombuffer(value, dtype=self.dtype).reshape(self.shape)
|
511 |
+
|
512 |
+
def _decode_values(self, values: list) -> ndarray:
|
513 |
+
shape = (len(values),) + self.shape
|
514 |
+
return np.frombuffer(b"".join(values), dtype=self.dtype).reshape(shape)
|
515 |
+
|
516 |
+
@staticmethod
|
517 |
+
def write(diter, outdir):
|
518 |
+
"""
|
519 |
+
diter is an iterator that has __len__ method
|
520 |
+
class Myiter():
|
521 |
+
def __init__(self, data):
|
522 |
+
self.data = data
|
523 |
+
def __iter__(self):
|
524 |
+
self.counter = 0
|
525 |
+
return self
|
526 |
+
def __len__(self):
|
527 |
+
return len(self.data)
|
528 |
+
def __next__(self):
|
529 |
+
if self.counter < len(self):
|
530 |
+
out = self.data[self.counter]
|
531 |
+
self.counter+=1
|
532 |
+
return out
|
533 |
+
else:
|
534 |
+
raise StopIteration
|
535 |
+
a = iter(Myiter([1,2,3]))
|
536 |
+
for i in a:
|
537 |
+
print(i)
|
538 |
+
"""
|
539 |
+
outdir = Path(outdir)
|
540 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
541 |
+
tmp_dir = Path("/tmp") / f"TEMP_{time.time()}"
|
542 |
+
tmp_dir.mkdir(parents=True, exist_ok=True)
|
543 |
+
# Create the database.
|
544 |
+
n = len(diter)
|
545 |
+
with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env:
|
546 |
+
# Add the protocol to the database.
|
547 |
+
with env.begin(write=True) as txn:
|
548 |
+
key = "protocol".encode("ascii")
|
549 |
+
value = pickle.dumps(pickle.DEFAULT_PROTOCOL)
|
550 |
+
txn.put(key=key, value=value, dupdata=False)
|
551 |
+
# Add the keys to the database.
|
552 |
+
with env.begin(write=True) as txn:
|
553 |
+
key = pickle.dumps("keys")
|
554 |
+
value = pickle.dumps(list(range(n)))
|
555 |
+
txn.put(key=key, value=value, dupdata=False)
|
556 |
+
# Extract the shape and dtype of the values.
|
557 |
+
value = next(iter(diter))
|
558 |
+
shape = value.shape
|
559 |
+
dtype = value.dtype
|
560 |
+
# Add the shape to the database.
|
561 |
+
with env.begin(write=True) as txn:
|
562 |
+
key = pickle.dumps("shape")
|
563 |
+
value = pickle.dumps(shape)
|
564 |
+
txn.put(key=key, value=value, dupdata=False)
|
565 |
+
# Add the dtype to the database.
|
566 |
+
with env.begin(write=True) as txn:
|
567 |
+
key = pickle.dumps("dtype")
|
568 |
+
value = pickle.dumps(dtype)
|
569 |
+
txn.put(key=key, value=value, dupdata=False)
|
570 |
+
# Add the values to the database.
|
571 |
+
with env.begin(write=True) as txn:
|
572 |
+
for key, value in tqdm(enumerate(iter(diter)), total=n, miniters=n//100, mininterval=300):
|
573 |
+
key = pickle.dumps(key)
|
574 |
+
value = pickle.dumps(value)
|
575 |
+
txn.put(key=key, value=value, dupdata=False)
|
576 |
+
|
577 |
+
# Move the database to its destination.
|
578 |
+
copy_tree(str(tmp_dir), str(outdir))
|
579 |
+
shutil.rmtree(str(tmp_dir))
|
580 |
+
|
581 |
+
|
582 |
+
|
583 |
+
class TensorDatabase(ArrayDatabase):
|
584 |
+
def _decode_value(self, value: bytes) -> Tensor:
|
585 |
+
return torch.from_numpy(super(TensorDatabase, self)._decode_value(value))
|
586 |
+
|
587 |
+
def _decode_values(self, values: list) -> Tensor:
|
588 |
+
return torch.from_numpy(super(TensorDatabase, self)._decode_values(values))
|
tools/ecc.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bchlib
|
2 |
+
import numpy as np
|
3 |
+
from typing import List, Tuple
|
4 |
+
import random
|
5 |
+
from copy import deepcopy
|
6 |
+
|
7 |
+
class RSC(object):
|
8 |
+
def __init__(self, data_bytes=16, ecc_bytes=4, verbose=False, **kwargs):
|
9 |
+
from reedsolo import RSCodec
|
10 |
+
self.rs = RSCodec(ecc_bytes)
|
11 |
+
if verbose:
|
12 |
+
print(f'Reed-Solomon ECC len: {ecc_bytes*8} bits')
|
13 |
+
self.data_len = data_bytes
|
14 |
+
self.dlen = data_bytes * 8 # data length in bits
|
15 |
+
self.ecc_len = ecc_bytes * 8 # ecc length in bits
|
16 |
+
|
17 |
+
def get_total_len(self):
|
18 |
+
return self.dlen + self.ecc_len
|
19 |
+
|
20 |
+
def encode_text(self, text: List[str]):
|
21 |
+
return np.array([self._encode_text(t) for t in text])
|
22 |
+
|
23 |
+
def _encode_text(self, text: str):
|
24 |
+
text = text + ' ' * (self.dlen // 8 - len(text))
|
25 |
+
out = self.rs.encode(text.encode('utf-8')) # bytearray
|
26 |
+
out = ''.join(format(x, '08b') for x in out) # bit string
|
27 |
+
out = np.array([int(x) for x in out], dtype=np.float32)
|
28 |
+
return out
|
29 |
+
|
30 |
+
def decode_text(self, data: np.array):
|
31 |
+
assert len(data.shape)==2
|
32 |
+
return [self._decode_text(d) for d in data]
|
33 |
+
|
34 |
+
def _decode_text(self, data: np.array):
|
35 |
+
assert len(data.shape)==1
|
36 |
+
data = ''.join([str(int(bit)) for bit in data])
|
37 |
+
data = bytes(int(data[i: i + 8], 2) for i in range(0, len(data), 8))
|
38 |
+
data = bytearray(data)
|
39 |
+
try:
|
40 |
+
data = self.rs.decode(data)[0]
|
41 |
+
data = data.decode('utf-8').strip()
|
42 |
+
except:
|
43 |
+
print('Error: Decode failed')
|
44 |
+
data = get_random_unicode(self.get_total_len()//8)
|
45 |
+
|
46 |
+
return data
|
47 |
+
|
48 |
+
def get_random_unicode(length):
|
49 |
+
# Update this to include code point ranges to be sampled
|
50 |
+
include_ranges = [
|
51 |
+
( 0x0021, 0x0021 ),
|
52 |
+
( 0x0023, 0x0026 ),
|
53 |
+
( 0x0028, 0x007E ),
|
54 |
+
( 0x00A1, 0x00AC ),
|
55 |
+
( 0x00AE, 0x00FF ),
|
56 |
+
( 0x0100, 0x017F ),
|
57 |
+
( 0x0180, 0x024F ),
|
58 |
+
( 0x2C60, 0x2C7F ),
|
59 |
+
( 0x16A0, 0x16F0 ),
|
60 |
+
( 0x0370, 0x0377 ),
|
61 |
+
( 0x037A, 0x037E ),
|
62 |
+
( 0x0384, 0x038A ),
|
63 |
+
( 0x038C, 0x038C ),
|
64 |
+
]
|
65 |
+
alphabet = [
|
66 |
+
chr(code_point) for current_range in include_ranges
|
67 |
+
for code_point in range(current_range[0], current_range[1] + 1)
|
68 |
+
]
|
69 |
+
return ''.join(random.choice(alphabet) for i in range(length))
|
70 |
+
|
71 |
+
|
72 |
+
class BCH(object):
|
73 |
+
def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, payload_len=100, verbose=True,**kwargs):
|
74 |
+
self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
|
75 |
+
self.payload_len = payload_len # in bits
|
76 |
+
self.data_len = (self.payload_len - self.bch.ecc_bytes*8)//7 # in ascii characters
|
77 |
+
assert self.data_len*7+self.bch.ecc_bytes*8 <= self.bch.n, f'Error! BCH with poly {BCH_POLYNOMIAL} and bits {BCH_BITS} can only encode max {self.bch.n//8} bytes of total payload'
|
78 |
+
if verbose:
|
79 |
+
print(f'BCH: POLYNOMIAL={BCH_POLYNOMIAL}, protected bits={BCH_BITS}, payload_len={payload_len} bits, data_len={self.data_len*7} bits ({self.data_len} ascii chars), ecc len={self.bch.ecc_bytes*8} bits')
|
80 |
+
|
81 |
+
def get_total_len(self):
|
82 |
+
return self.payload_len
|
83 |
+
|
84 |
+
def encode_text(self, text: List[str]):
|
85 |
+
return np.array([self._encode_text(t) for t in text])
|
86 |
+
|
87 |
+
def _encode_text(self, text: str):
|
88 |
+
text = text + ' ' * (self.data_len - len(text))
|
89 |
+
# data = text.encode('utf-8') # bytearray
|
90 |
+
data = encode_text_ascii(text) # bytearray
|
91 |
+
ecc = self.bch.encode(data) # bytearray
|
92 |
+
packet = data + ecc # payload in bytearray
|
93 |
+
packet = ''.join(format(x, '08b') for x in packet)
|
94 |
+
packet = [int(x) for x in packet]
|
95 |
+
packet.extend([0]*(self.payload_len - len(packet)))
|
96 |
+
packet = np.array(packet, dtype=np.float32)
|
97 |
+
return packet
|
98 |
+
|
99 |
+
def decode_text(self, data: np.array):
|
100 |
+
assert len(data.shape)==2
|
101 |
+
return [self._decode_text(d) for d in data]
|
102 |
+
|
103 |
+
def _decode_text(self, packet: np.array):
|
104 |
+
assert len(packet.shape)==1
|
105 |
+
packet = ''.join([str(int(bit)) for bit in packet]) # bit string
|
106 |
+
packet = packet[:(len(packet)//8*8)] # trim to multiple of 8 bits
|
107 |
+
packet = bytes(int(packet[i: i + 8], 2) for i in range(0, len(packet), 8))
|
108 |
+
packet = bytearray(packet)
|
109 |
+
# assert len(packet) == self.data_len + self.bch.ecc_bytes
|
110 |
+
data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
|
111 |
+
data0 = decode_text_ascii(deepcopy(data)).strip()
|
112 |
+
bitflips = self.bch.decode_inplace(data, ecc)
|
113 |
+
if bitflips == -1: # error, return random text
|
114 |
+
data = data0
|
115 |
+
else:
|
116 |
+
# data = data.decode('utf-8').strip()
|
117 |
+
data = decode_text_ascii(data).strip()
|
118 |
+
return data
|
119 |
+
|
120 |
+
|
121 |
+
def encode_text_ascii(text: str):
|
122 |
+
# encode text to 7-bit ascii
|
123 |
+
# input: text, str
|
124 |
+
# output: encoded text, bytearray
|
125 |
+
text_int7 = [ord(t) & 127 for t in text]
|
126 |
+
text_bitstr = ''.join(format(t,'07b') for t in text_int7)
|
127 |
+
if len(text_bitstr) % 8 != 0:
|
128 |
+
text_bitstr = '0'*(8-len(text_bitstr)%8) + text_bitstr # pad to multiple of 8
|
129 |
+
text_int8 = [int(text_bitstr[i:i+8], 2) for i in range(0, len(text_bitstr), 8)]
|
130 |
+
return bytearray(text_int8)
|
131 |
+
|
132 |
+
|
133 |
+
def decode_text_ascii(text: bytearray):
|
134 |
+
# decode text from 7-bit ascii
|
135 |
+
# input: text, bytearray
|
136 |
+
# output: decoded text, str
|
137 |
+
text_bitstr = ''.join(format(t,'08b') for t in text) # bit string
|
138 |
+
pad = len(text_bitstr) % 7
|
139 |
+
if pad != 0: # has padding, remove
|
140 |
+
text_bitstr = text_bitstr[pad:]
|
141 |
+
text_int7 = [int(text_bitstr[i:i+7], 2) for i in range(0, len(text_bitstr), 7)]
|
142 |
+
text_bytes = bytes(text_int7)
|
143 |
+
return text_bytes.decode('utf-8')
|
144 |
+
|
145 |
+
|
146 |
+
class ECC(object):
|
147 |
+
def __init__(self, BCH_POLYNOMIAL = 137, BCH_BITS = 5, **kwargs):
|
148 |
+
self.bch = bchlib.BCH(BCH_POLYNOMIAL, BCH_BITS)
|
149 |
+
|
150 |
+
def get_total_len(self):
|
151 |
+
return 100
|
152 |
+
|
153 |
+
def _encode(self, x):
|
154 |
+
# x: 56 bits, {0, 1}, np.array
|
155 |
+
# return: 100 bits, {0, 1}, np.array
|
156 |
+
dlen = len(x)
|
157 |
+
data_str = ''.join(str(x) for x in x.astype(int))
|
158 |
+
packet = bytes(int(data_str[i: i + 8], 2) for i in range(0, dlen, 8))
|
159 |
+
packet = bytearray(packet)
|
160 |
+
ecc = self.bch.encode(packet)
|
161 |
+
packet = packet + ecc # 96 bits
|
162 |
+
packet = ''.join(format(x, '08b') for x in packet)
|
163 |
+
packet = [int(x) for x in packet]
|
164 |
+
packet.extend([0, 0, 0, 0])
|
165 |
+
packet = np.array(packet, dtype=np.float32) # 100
|
166 |
+
return packet
|
167 |
+
|
168 |
+
def _decode(self, x):
|
169 |
+
# x: 100 bits, {0, 1}, np.array
|
170 |
+
# return: 56 bits, {0, 1}, np.array
|
171 |
+
packet_binary = "".join([str(int(bit)) for bit in x])
|
172 |
+
packet = bytes(int(packet_binary[i: i + 8], 2) for i in range(0, len(packet_binary), 8))
|
173 |
+
packet = bytearray(packet)
|
174 |
+
|
175 |
+
data, ecc = packet[:-self.bch.ecc_bytes], packet[-self.bch.ecc_bytes:]
|
176 |
+
bitflips = self.bch.decode_inplace(data, ecc)
|
177 |
+
if bitflips == -1: # error, return random data
|
178 |
+
data = np.random.binomial(1, .5, 56)
|
179 |
+
else:
|
180 |
+
data = ''.join(format(x, '08b') for x in data)
|
181 |
+
data = np.array([int(x) for x in data], dtype=np.float32)
|
182 |
+
return data # 56 bits
|
183 |
+
|
184 |
+
def _generate(self):
|
185 |
+
dlen = 56
|
186 |
+
data= np.random.binomial(1, .5, dlen)
|
187 |
+
packet = self._encode(data)
|
188 |
+
return packet, data
|
189 |
+
|
190 |
+
def generate(self, nsamples=1):
|
191 |
+
# generate random 56 bit secret
|
192 |
+
data = [self._generate() for _ in range(nsamples)]
|
193 |
+
data = (np.array([d[0] for d in data]), np.array([d[1] for d in data]))
|
194 |
+
return data # data with ecc, data org
|
195 |
+
|
196 |
+
def _to_text(self, data):
|
197 |
+
# data: {0, 1}, np.array
|
198 |
+
# return: str
|
199 |
+
data = ''.join([str(int(bit)) for bit in data])
|
200 |
+
all_bytes = [ data[i: i+8] for i in range(0, len(data), 8) ]
|
201 |
+
text = ''.join([chr(int(byte, 2)) for byte in all_bytes])
|
202 |
+
return text.strip()
|
203 |
+
|
204 |
+
def _to_binary(self, s):
|
205 |
+
if isinstance(s, str):
|
206 |
+
out = ''.join([ format(ord(i), "08b") for i in s ])
|
207 |
+
elif isinstance(s, bytes):
|
208 |
+
out = ''.join([ format(i, "08b") for i in s ])
|
209 |
+
elif isinstance(s, np.ndarray) and s.dtype is np.dtype(bool):
|
210 |
+
out = ''.join([chr(int(i)) for i in s])
|
211 |
+
elif isinstance(s, int) or isinstance(s, np.uint8):
|
212 |
+
out = format(s, "08b")
|
213 |
+
elif isinstance(s, np.ndarray):
|
214 |
+
out = [ format(i, "08b") for i in s ]
|
215 |
+
else:
|
216 |
+
raise TypeError("Type not supported.")
|
217 |
+
|
218 |
+
return np.array([float(i) for i in out], dtype=np.float32)
|
219 |
+
|
220 |
+
def _encode_text(self, s):
|
221 |
+
s = s + ' '*(7-len(s)) # 7 chars
|
222 |
+
s = self._to_binary(s) # 56 bits
|
223 |
+
packet = self._encode(s) # 100 bits
|
224 |
+
return packet, s
|
225 |
+
|
226 |
+
def encode_text(self, secret_list, return_pre_ecc=False):
|
227 |
+
"""encode secret with BCH ECC.
|
228 |
+
Input: secret (list of strings)
|
229 |
+
Output: secret (np array) with shape (B, 100) type float23, val {0,1}"""
|
230 |
+
assert np.all(np.array([len(s) for s in secret_list]) <= 7), 'Error! all strings must be less than 7 characters'
|
231 |
+
secret_list = [self._encode_text(s) for s in secret_list]
|
232 |
+
ecc = np.array([s[0] for s in secret_list], dtype=np.float32)
|
233 |
+
if return_pre_ecc:
|
234 |
+
return ecc, np.array([s[1] for s in secret_list], dtype=np.float32)
|
235 |
+
return ecc
|
236 |
+
|
237 |
+
def decode_text(self, data):
|
238 |
+
"""Decode secret with BCH ECC and convert to string.
|
239 |
+
Input: secret (torch.tensor) with shape (B, 100) type bool
|
240 |
+
Output: secret (B, 56)"""
|
241 |
+
data = self.decode(data)
|
242 |
+
data = [self._to_text(d) for d in data]
|
243 |
+
return data
|
244 |
+
|
245 |
+
def decode(self, data):
|
246 |
+
"""Decode secret with BCH ECC and convert to string.
|
247 |
+
Input: secret (torch.tensor) with shape (B, 100) type bool
|
248 |
+
Output: secret (B, 56)"""
|
249 |
+
data = data[:, :96]
|
250 |
+
data = [self._decode(d) for d in data]
|
251 |
+
return np.array(data)
|
252 |
+
|
253 |
+
def test_ecc():
|
254 |
+
ecc = ECC()
|
255 |
+
batch_size = 10
|
256 |
+
secret_ecc, secret_org = ecc.generate(batch_size) # 10x100 ecc secret, 10x56 org secret
|
257 |
+
# modify secret_ecc
|
258 |
+
secret_pred = secret_ecc.copy()
|
259 |
+
secret_pred[:,3:6] = 1 - secret_pred[:,3:6]
|
260 |
+
# pass secret_ecc to model and get predicted as secret_pred
|
261 |
+
secret_pred_org = ecc.decode(secret_pred) # 10x56
|
262 |
+
assert np.all(secret_pred_org == secret_org) # 10
|
263 |
+
|
264 |
+
|
265 |
+
def test_bch():
|
266 |
+
# test 100 bit
|
267 |
+
def check(text, poly, k, l):
|
268 |
+
bch = BCH(poly, k, l)
|
269 |
+
# text = 'secrets'
|
270 |
+
encode = bch.encode_text([text])
|
271 |
+
for ind in np.random.choice(l, k):
|
272 |
+
encode[0, ind] = 1 - encode[0, ind]
|
273 |
+
text_recon = bch.decode_text(encode)[0]
|
274 |
+
assert text==text_recon
|
275 |
+
|
276 |
+
check('secrets', 137, 5, 100)
|
277 |
+
check('some secret', 285, 10, 160)
|
278 |
+
|
279 |
+
if __name__ == '__main__':
|
280 |
+
test_ecc()
|
281 |
+
test_bch()
|
tools/eval_metrics.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import skimage.metrics
|
4 |
+
import lpips
|
5 |
+
from PIL import Image
|
6 |
+
from .sifid import SIFID
|
7 |
+
|
8 |
+
|
9 |
+
def resize_array(x, size=256):
|
10 |
+
"""
|
11 |
+
Resize image array to given size.
|
12 |
+
Args:
|
13 |
+
x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
|
14 |
+
size (int): Size of output image.
|
15 |
+
Returns:
|
16 |
+
(np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
|
17 |
+
"""
|
18 |
+
if x.shape[1] != size:
|
19 |
+
x = [Image.fromarray(x[i]).resize((size, size), resample=Image.BILINEAR) for i in range(x.shape[0])]
|
20 |
+
x = np.array([np.array(i) for i in x])
|
21 |
+
return x
|
22 |
+
|
23 |
+
|
24 |
+
def resize_tensor(x, size=256):
|
25 |
+
"""
|
26 |
+
Resize image tensor to given size.
|
27 |
+
Args:
|
28 |
+
x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
29 |
+
size (int): Size of output image.
|
30 |
+
Returns:
|
31 |
+
(torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
32 |
+
"""
|
33 |
+
if x.shape[2] != size:
|
34 |
+
x = torch.nn.functional.interpolate(x, size=(size, size), mode='bilinear', align_corners=False)
|
35 |
+
return x
|
36 |
+
|
37 |
+
|
38 |
+
def normalise(x):
|
39 |
+
"""
|
40 |
+
Normalise image array to range [-1, 1] and tensor.
|
41 |
+
Args:
|
42 |
+
x (np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
|
43 |
+
Returns:
|
44 |
+
(torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
45 |
+
"""
|
46 |
+
x = x.astype(np.float32)
|
47 |
+
x = x / 255
|
48 |
+
x = (x - 0.5) / 0.5
|
49 |
+
x = torch.from_numpy(x)
|
50 |
+
x = x.permute(0, 3, 1, 2)
|
51 |
+
return x
|
52 |
+
|
53 |
+
|
54 |
+
def unormalise(x, vrange=[-1, 1]):
|
55 |
+
"""
|
56 |
+
Unormalise image tensor to range [0, 255] and RGB array.
|
57 |
+
Args:
|
58 |
+
x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
59 |
+
Returns:
|
60 |
+
(np.ndarray): Image array of shape (N, H, W, C) in range [0, 255].
|
61 |
+
"""
|
62 |
+
x = (x - vrange[0])/(vrange[1] - vrange[0])
|
63 |
+
x = x * 255
|
64 |
+
x = x.permute(0, 2, 3, 1)
|
65 |
+
x = x.cpu().numpy().astype(np.uint8)
|
66 |
+
return x
|
67 |
+
|
68 |
+
|
69 |
+
def compute_mse(x, y):
|
70 |
+
"""
|
71 |
+
Compute mean squared error between two image arrays.
|
72 |
+
Args:
|
73 |
+
x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
74 |
+
y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
75 |
+
Returns:
|
76 |
+
(1darray): Mean squared error.
|
77 |
+
"""
|
78 |
+
return np.square(x - y).reshape(x.shape[0], -1).mean(axis=1)
|
79 |
+
|
80 |
+
|
81 |
+
def compute_psnr(x, y):
|
82 |
+
"""
|
83 |
+
Compute peak signal-to-noise ratio between two images.
|
84 |
+
Args:
|
85 |
+
x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
86 |
+
y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
87 |
+
Returns:
|
88 |
+
(float): Peak signal-to-noise ratio.
|
89 |
+
"""
|
90 |
+
return 10 * np.log10(255 ** 2 / compute_mse(x, y))
|
91 |
+
|
92 |
+
|
93 |
+
def compute_ssim(x, y):
|
94 |
+
"""
|
95 |
+
Compute structural similarity index between two images.
|
96 |
+
Args:
|
97 |
+
x (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
98 |
+
y (np.ndarray): Image of shape (N, H, W, C) in range [0, 255].
|
99 |
+
Returns:
|
100 |
+
(float): Structural similarity index.
|
101 |
+
"""
|
102 |
+
return np.array([skimage.metrics.structural_similarity(xi, yi, channel_axis=2, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=255) for xi, yi in zip(x, y)])
|
103 |
+
|
104 |
+
|
105 |
+
def compute_lpips(x, y, net='alex'):
|
106 |
+
"""
|
107 |
+
Compute LPIPS between two images.
|
108 |
+
Args:
|
109 |
+
x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
110 |
+
y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
111 |
+
Returns:
|
112 |
+
(float): LPIPS.
|
113 |
+
"""
|
114 |
+
lpips_fn = lpips.LPIPS(net=net, verbose=False).cuda() if isinstance(net, str) else net
|
115 |
+
x, y = x.cuda(), y.cuda()
|
116 |
+
return lpips_fn(x, y).detach().cpu().numpy().squeeze()
|
117 |
+
|
118 |
+
|
119 |
+
def compute_sifid(x, y, net=None):
|
120 |
+
"""
|
121 |
+
Compute SIFID between two images.
|
122 |
+
Args:
|
123 |
+
x (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
124 |
+
y (torch.Tensor): Image tensor of shape (N, C, H, W) in range [-1, 1].
|
125 |
+
Returns:
|
126 |
+
(float): SIFID.
|
127 |
+
"""
|
128 |
+
fn = SIFID() if net is None else net
|
129 |
+
out = [fn(xi, yi) for xi, yi in zip(x, y)]
|
130 |
+
return np.array(out)
|
tools/fid.py
ADDED
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
2 |
+
|
3 |
+
The FID metric calculates the distance between two distributions of images.
|
4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
6 |
+
|
7 |
+
When run as a stand-alone program, it compares the distribution of
|
8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
9 |
+
distribution given by summary statistics (in pickle format).
|
10 |
+
|
11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
13 |
+
samples respectively.
|
14 |
+
|
15 |
+
See --help to see further details.
|
16 |
+
|
17 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
18 |
+
of Tensorflow
|
19 |
+
|
20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
21 |
+
|
22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
you may not use this file except in compliance with the License.
|
24 |
+
You may obtain a copy of the License at
|
25 |
+
|
26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
|
28 |
+
Unless required by applicable law or agreed to in writing, software
|
29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
See the License for the specific language governing permissions and
|
32 |
+
limitations under the License.
|
33 |
+
"""
|
34 |
+
import os
|
35 |
+
import pathlib
|
36 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
37 |
+
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import torchvision.transforms as TF
|
41 |
+
from PIL import Image
|
42 |
+
from scipy import linalg
|
43 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
44 |
+
import torch.nn as nn
|
45 |
+
import torch.nn.functional as F
|
46 |
+
import torchvision
|
47 |
+
|
48 |
+
try:
|
49 |
+
from tqdm import tqdm
|
50 |
+
except ImportError:
|
51 |
+
# If tqdm is not available, provide a mock version of it
|
52 |
+
def tqdm(x):
|
53 |
+
return x
|
54 |
+
|
55 |
+
|
56 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
57 |
+
'tif', 'tiff', 'webp'}
|
58 |
+
|
59 |
+
|
60 |
+
try:
|
61 |
+
from torchvision.models.utils import load_state_dict_from_url
|
62 |
+
except ImportError:
|
63 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
64 |
+
|
65 |
+
# Inception weights ported to Pytorch from
|
66 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
67 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
68 |
+
|
69 |
+
|
70 |
+
class InceptionV3(nn.Module):
|
71 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
72 |
+
|
73 |
+
# Index of default block of inception to return,
|
74 |
+
# corresponds to output of final average pooling
|
75 |
+
DEFAULT_BLOCK_INDEX = 3
|
76 |
+
|
77 |
+
# Maps feature dimensionality to their output blocks indices
|
78 |
+
BLOCK_INDEX_BY_DIM = {
|
79 |
+
64: 0, # First max pooling features
|
80 |
+
192: 1, # Second max pooling featurs
|
81 |
+
768: 2, # Pre-aux classifier features
|
82 |
+
2048: 3 # Final average pooling features
|
83 |
+
}
|
84 |
+
|
85 |
+
def __init__(self,
|
86 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
87 |
+
resize_input=True,
|
88 |
+
normalize_input=True,
|
89 |
+
requires_grad=False,
|
90 |
+
use_fid_inception=True):
|
91 |
+
"""Build pretrained InceptionV3
|
92 |
+
|
93 |
+
Parameters
|
94 |
+
----------
|
95 |
+
output_blocks : list of int
|
96 |
+
Indices of blocks to return features of. Possible values are:
|
97 |
+
- 0: corresponds to output of first max pooling
|
98 |
+
- 1: corresponds to output of second max pooling
|
99 |
+
- 2: corresponds to output which is fed to aux classifier
|
100 |
+
- 3: corresponds to output of final average pooling
|
101 |
+
resize_input : bool
|
102 |
+
If true, bilinearly resizes input to width and height 299 before
|
103 |
+
feeding input to model. As the network without fully connected
|
104 |
+
layers is fully convolutional, it should be able to handle inputs
|
105 |
+
of arbitrary size, so resizing might not be strictly needed
|
106 |
+
normalize_input : bool
|
107 |
+
If true, scales the input from range (0, 1) to the range the
|
108 |
+
pretrained Inception network expects, namely (-1, 1)
|
109 |
+
requires_grad : bool
|
110 |
+
If true, parameters of the model require gradients. Possibly useful
|
111 |
+
for finetuning the network
|
112 |
+
use_fid_inception : bool
|
113 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
114 |
+
FID implementation. If false, uses the pretrained Inception model
|
115 |
+
available in torchvision. The FID Inception model has different
|
116 |
+
weights and a slightly different structure from torchvision's
|
117 |
+
Inception model. If you want to compute FID scores, you are
|
118 |
+
strongly advised to set this parameter to true to get comparable
|
119 |
+
results.
|
120 |
+
"""
|
121 |
+
super(InceptionV3, self).__init__()
|
122 |
+
|
123 |
+
self.resize_input = resize_input
|
124 |
+
self.normalize_input = normalize_input
|
125 |
+
self.output_blocks = sorted(output_blocks)
|
126 |
+
self.last_needed_block = max(output_blocks)
|
127 |
+
|
128 |
+
assert self.last_needed_block <= 3, \
|
129 |
+
'Last possible output block index is 3'
|
130 |
+
|
131 |
+
self.blocks = nn.ModuleList()
|
132 |
+
|
133 |
+
if use_fid_inception:
|
134 |
+
inception = fid_inception_v3()
|
135 |
+
else:
|
136 |
+
inception = _inception_v3(weights='DEFAULT')
|
137 |
+
|
138 |
+
# Block 0: input to maxpool1
|
139 |
+
block0 = [
|
140 |
+
inception.Conv2d_1a_3x3,
|
141 |
+
inception.Conv2d_2a_3x3,
|
142 |
+
inception.Conv2d_2b_3x3,
|
143 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
144 |
+
]
|
145 |
+
self.blocks.append(nn.Sequential(*block0))
|
146 |
+
|
147 |
+
# Block 1: maxpool1 to maxpool2
|
148 |
+
if self.last_needed_block >= 1:
|
149 |
+
block1 = [
|
150 |
+
inception.Conv2d_3b_1x1,
|
151 |
+
inception.Conv2d_4a_3x3,
|
152 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
153 |
+
]
|
154 |
+
self.blocks.append(nn.Sequential(*block1))
|
155 |
+
|
156 |
+
# Block 2: maxpool2 to aux classifier
|
157 |
+
if self.last_needed_block >= 2:
|
158 |
+
block2 = [
|
159 |
+
inception.Mixed_5b,
|
160 |
+
inception.Mixed_5c,
|
161 |
+
inception.Mixed_5d,
|
162 |
+
inception.Mixed_6a,
|
163 |
+
inception.Mixed_6b,
|
164 |
+
inception.Mixed_6c,
|
165 |
+
inception.Mixed_6d,
|
166 |
+
inception.Mixed_6e,
|
167 |
+
]
|
168 |
+
self.blocks.append(nn.Sequential(*block2))
|
169 |
+
|
170 |
+
# Block 3: aux classifier to final avgpool
|
171 |
+
if self.last_needed_block >= 3:
|
172 |
+
block3 = [
|
173 |
+
inception.Mixed_7a,
|
174 |
+
inception.Mixed_7b,
|
175 |
+
inception.Mixed_7c,
|
176 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
177 |
+
]
|
178 |
+
self.blocks.append(nn.Sequential(*block3))
|
179 |
+
|
180 |
+
for param in self.parameters():
|
181 |
+
param.requires_grad = requires_grad
|
182 |
+
|
183 |
+
def forward(self, inp):
|
184 |
+
"""Get Inception feature maps
|
185 |
+
|
186 |
+
Parameters
|
187 |
+
----------
|
188 |
+
inp : torch.autograd.Variable
|
189 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
190 |
+
range (0, 1)
|
191 |
+
|
192 |
+
Returns
|
193 |
+
-------
|
194 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
195 |
+
block, sorted ascending by index
|
196 |
+
"""
|
197 |
+
outp = []
|
198 |
+
x = inp
|
199 |
+
|
200 |
+
if self.resize_input:
|
201 |
+
x = F.interpolate(x,
|
202 |
+
size=(299, 299),
|
203 |
+
mode='bilinear',
|
204 |
+
align_corners=False)
|
205 |
+
|
206 |
+
if self.normalize_input:
|
207 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
208 |
+
|
209 |
+
for idx, block in enumerate(self.blocks):
|
210 |
+
x = block(x)
|
211 |
+
if idx in self.output_blocks:
|
212 |
+
outp.append(x)
|
213 |
+
|
214 |
+
if idx == self.last_needed_block:
|
215 |
+
break
|
216 |
+
|
217 |
+
return outp
|
218 |
+
|
219 |
+
|
220 |
+
def _inception_v3(*args, **kwargs):
|
221 |
+
"""Wraps `torchvision.models.inception_v3`"""
|
222 |
+
try:
|
223 |
+
version = tuple(map(int, torchvision.__version__.split('.')[:2]))
|
224 |
+
except ValueError:
|
225 |
+
# Just a caution against weird version strings
|
226 |
+
version = (0,)
|
227 |
+
|
228 |
+
# Skips default weight inititialization if supported by torchvision
|
229 |
+
# version. See https://github.com/mseitzer/pytorch-fid/issues/28.
|
230 |
+
if version >= (0, 6):
|
231 |
+
kwargs['init_weights'] = False
|
232 |
+
|
233 |
+
# Backwards compatibility: `weights` argument was handled by `pretrained`
|
234 |
+
# argument prior to version 0.13.
|
235 |
+
if version < (0, 13) and 'weights' in kwargs:
|
236 |
+
if kwargs['weights'] == 'DEFAULT':
|
237 |
+
kwargs['pretrained'] = True
|
238 |
+
elif kwargs['weights'] is None:
|
239 |
+
kwargs['pretrained'] = False
|
240 |
+
else:
|
241 |
+
raise ValueError(
|
242 |
+
'weights=={} not supported in torchvision {}'.format(
|
243 |
+
kwargs['weights'], torchvision.__version__
|
244 |
+
)
|
245 |
+
)
|
246 |
+
del kwargs['weights']
|
247 |
+
|
248 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
249 |
+
|
250 |
+
|
251 |
+
def fid_inception_v3():
|
252 |
+
"""Build pretrained Inception model for FID computation
|
253 |
+
|
254 |
+
The Inception model for FID computation uses a different set of weights
|
255 |
+
and has a slightly different structure than torchvision's Inception.
|
256 |
+
|
257 |
+
This method first constructs torchvision's Inception and then patches the
|
258 |
+
necessary parts that are different in the FID Inception model.
|
259 |
+
"""
|
260 |
+
inception = _inception_v3(num_classes=1008,
|
261 |
+
aux_logits=False,
|
262 |
+
weights=None)
|
263 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
264 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
265 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
266 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
267 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
268 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
269 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
270 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
271 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
272 |
+
|
273 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
274 |
+
inception.load_state_dict(state_dict)
|
275 |
+
return inception
|
276 |
+
|
277 |
+
|
278 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
279 |
+
"""InceptionA block patched for FID computation"""
|
280 |
+
def __init__(self, in_channels, pool_features):
|
281 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
282 |
+
|
283 |
+
def forward(self, x):
|
284 |
+
branch1x1 = self.branch1x1(x)
|
285 |
+
|
286 |
+
branch5x5 = self.branch5x5_1(x)
|
287 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
288 |
+
|
289 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
290 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
291 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
292 |
+
|
293 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
294 |
+
# its average calculation
|
295 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
296 |
+
count_include_pad=False)
|
297 |
+
branch_pool = self.branch_pool(branch_pool)
|
298 |
+
|
299 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
300 |
+
return torch.cat(outputs, 1)
|
301 |
+
|
302 |
+
|
303 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
304 |
+
"""InceptionC block patched for FID computation"""
|
305 |
+
def __init__(self, in_channels, channels_7x7):
|
306 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
307 |
+
|
308 |
+
def forward(self, x):
|
309 |
+
branch1x1 = self.branch1x1(x)
|
310 |
+
|
311 |
+
branch7x7 = self.branch7x7_1(x)
|
312 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
313 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
314 |
+
|
315 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
316 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
317 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
318 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
319 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
320 |
+
|
321 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
322 |
+
# its average calculation
|
323 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
324 |
+
count_include_pad=False)
|
325 |
+
branch_pool = self.branch_pool(branch_pool)
|
326 |
+
|
327 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
328 |
+
return torch.cat(outputs, 1)
|
329 |
+
|
330 |
+
|
331 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
332 |
+
"""First InceptionE block patched for FID computation"""
|
333 |
+
def __init__(self, in_channels):
|
334 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
335 |
+
|
336 |
+
def forward(self, x):
|
337 |
+
branch1x1 = self.branch1x1(x)
|
338 |
+
|
339 |
+
branch3x3 = self.branch3x3_1(x)
|
340 |
+
branch3x3 = [
|
341 |
+
self.branch3x3_2a(branch3x3),
|
342 |
+
self.branch3x3_2b(branch3x3),
|
343 |
+
]
|
344 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
345 |
+
|
346 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
347 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
348 |
+
branch3x3dbl = [
|
349 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
350 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
351 |
+
]
|
352 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
353 |
+
|
354 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
355 |
+
# its average calculation
|
356 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
357 |
+
count_include_pad=False)
|
358 |
+
branch_pool = self.branch_pool(branch_pool)
|
359 |
+
|
360 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
361 |
+
return torch.cat(outputs, 1)
|
362 |
+
|
363 |
+
|
364 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
365 |
+
"""Second InceptionE block patched for FID computation"""
|
366 |
+
def __init__(self, in_channels):
|
367 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
368 |
+
|
369 |
+
def forward(self, x):
|
370 |
+
branch1x1 = self.branch1x1(x)
|
371 |
+
|
372 |
+
branch3x3 = self.branch3x3_1(x)
|
373 |
+
branch3x3 = [
|
374 |
+
self.branch3x3_2a(branch3x3),
|
375 |
+
self.branch3x3_2b(branch3x3),
|
376 |
+
]
|
377 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
378 |
+
|
379 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
380 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
381 |
+
branch3x3dbl = [
|
382 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
383 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
384 |
+
]
|
385 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
386 |
+
|
387 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
388 |
+
# pooling. This is likely an error in this specific Inception
|
389 |
+
# implementation, as other Inception models use average pooling here
|
390 |
+
# (which matches the description in the paper).
|
391 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
392 |
+
branch_pool = self.branch_pool(branch_pool)
|
393 |
+
|
394 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
395 |
+
return torch.cat(outputs, 1)
|
396 |
+
|
397 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
398 |
+
def __init__(self, files, transforms=None):
|
399 |
+
self.files = files
|
400 |
+
self.transforms = transforms
|
401 |
+
|
402 |
+
def __len__(self):
|
403 |
+
return len(self.files)
|
404 |
+
|
405 |
+
def __getitem__(self, i):
|
406 |
+
path = self.files[i]
|
407 |
+
img = Image.open(path).convert('RGB')
|
408 |
+
if self.transforms is not None:
|
409 |
+
img = self.transforms(img)
|
410 |
+
return img
|
411 |
+
|
412 |
+
|
413 |
+
def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
|
414 |
+
num_workers=1, resize=0):
|
415 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
416 |
+
|
417 |
+
Params:
|
418 |
+
-- files : List of image files paths
|
419 |
+
-- model : Instance of inception model
|
420 |
+
-- batch_size : Batch size of images for the model to process at once.
|
421 |
+
Make sure that the number of samples is a multiple of
|
422 |
+
the batch size, otherwise some samples are ignored. This
|
423 |
+
behavior is retained to match the original FID score
|
424 |
+
implementation.
|
425 |
+
-- dims : Dimensionality of features returned by Inception
|
426 |
+
-- device : Device to run calculations
|
427 |
+
-- num_workers : Number of parallel dataloader workers
|
428 |
+
|
429 |
+
Returns:
|
430 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
431 |
+
activations of the given tensor when feeding inception with the
|
432 |
+
query tensor.
|
433 |
+
"""
|
434 |
+
model.eval()
|
435 |
+
|
436 |
+
if batch_size > len(files):
|
437 |
+
print(('Warning: batch size is bigger than the data size. '
|
438 |
+
'Setting batch size to data size'))
|
439 |
+
batch_size = len(files)
|
440 |
+
if resize > 0:
|
441 |
+
tform = TF.Compose([TF.Resize((resize, resize)), TF.ToTensor()])
|
442 |
+
else:
|
443 |
+
tform = TF.ToTensor()
|
444 |
+
dataset = ImagePathDataset(files, transforms=tform)
|
445 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
446 |
+
batch_size=batch_size,
|
447 |
+
shuffle=False,
|
448 |
+
drop_last=False,
|
449 |
+
num_workers=num_workers)
|
450 |
+
|
451 |
+
pred_arr = np.empty((len(files), dims))
|
452 |
+
|
453 |
+
start_idx = 0
|
454 |
+
|
455 |
+
for batch in tqdm(dataloader):
|
456 |
+
batch = batch.to(device)
|
457 |
+
|
458 |
+
with torch.no_grad():
|
459 |
+
pred = model(batch)[0]
|
460 |
+
|
461 |
+
# If model output is not scalar, apply global spatial average pooling.
|
462 |
+
# This happens if you choose a dimensionality not equal 2048.
|
463 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
464 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
465 |
+
|
466 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
467 |
+
|
468 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
469 |
+
|
470 |
+
start_idx = start_idx + pred.shape[0]
|
471 |
+
|
472 |
+
return pred_arr
|
473 |
+
|
474 |
+
|
475 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
476 |
+
"""Numpy implementation of the Frechet Distance.
|
477 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
478 |
+
and X_2 ~ N(mu_2, C_2) is
|
479 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
480 |
+
|
481 |
+
Stable version by Dougal J. Sutherland.
|
482 |
+
|
483 |
+
Params:
|
484 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
485 |
+
inception net (like returned by the function 'get_predictions')
|
486 |
+
for generated samples.
|
487 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
488 |
+
representative data set.
|
489 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
490 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
491 |
+
representative data set.
|
492 |
+
|
493 |
+
Returns:
|
494 |
+
-- : The Frechet Distance.
|
495 |
+
"""
|
496 |
+
|
497 |
+
mu1 = np.atleast_1d(mu1)
|
498 |
+
mu2 = np.atleast_1d(mu2)
|
499 |
+
|
500 |
+
sigma1 = np.atleast_2d(sigma1)
|
501 |
+
sigma2 = np.atleast_2d(sigma2)
|
502 |
+
|
503 |
+
assert mu1.shape == mu2.shape, \
|
504 |
+
'Training and test mean vectors have different lengths'
|
505 |
+
assert sigma1.shape == sigma2.shape, \
|
506 |
+
'Training and test covariances have different dimensions'
|
507 |
+
|
508 |
+
diff = mu1 - mu2
|
509 |
+
|
510 |
+
# Product might be almost singular
|
511 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
512 |
+
if not np.isfinite(covmean).all():
|
513 |
+
msg = ('fid calculation produces singular product; '
|
514 |
+
'adding %s to diagonal of cov estimates') % eps
|
515 |
+
print(msg)
|
516 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
517 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
518 |
+
|
519 |
+
# Numerical error might give slight imaginary component
|
520 |
+
if np.iscomplexobj(covmean):
|
521 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
522 |
+
m = np.max(np.abs(covmean.imag))
|
523 |
+
raise ValueError('Imaginary component {}'.format(m))
|
524 |
+
covmean = covmean.real
|
525 |
+
|
526 |
+
tr_covmean = np.trace(covmean)
|
527 |
+
|
528 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
529 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
530 |
+
|
531 |
+
|
532 |
+
def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
|
533 |
+
device='cpu', num_workers=1, resize=0):
|
534 |
+
"""Calculation of the statistics used by the FID.
|
535 |
+
Params:
|
536 |
+
-- files : List of image files paths
|
537 |
+
-- model : Instance of inception model
|
538 |
+
-- batch_size : The images numpy array is split into batches with
|
539 |
+
batch size batch_size. A reasonable batch size
|
540 |
+
depends on the hardware.
|
541 |
+
-- dims : Dimensionality of features returned by Inception
|
542 |
+
-- device : Device to run calculations
|
543 |
+
-- num_workers : Number of parallel dataloader workers
|
544 |
+
|
545 |
+
Returns:
|
546 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
547 |
+
the inception model.
|
548 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
549 |
+
the inception model.
|
550 |
+
"""
|
551 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers, resize)
|
552 |
+
mu = np.mean(act, axis=0)
|
553 |
+
sigma = np.cov(act, rowvar=False)
|
554 |
+
return mu, sigma
|
555 |
+
|
556 |
+
|
557 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device,
|
558 |
+
num_workers=1, nimages=None, resize=0):
|
559 |
+
if path.endswith('.npz'):
|
560 |
+
with np.load(path) as f:
|
561 |
+
m, s = f['mu'][:], f['sigma'][:]
|
562 |
+
else:
|
563 |
+
path = pathlib.Path(path)
|
564 |
+
|
565 |
+
files = sorted([file for ext in IMAGE_EXTENSIONS
|
566 |
+
for file in path.glob('**/*.{}'.format(ext))])
|
567 |
+
nfiles = len(files)
|
568 |
+
n = nfiles if nimages is None else min(nimages, nfiles)
|
569 |
+
print(f'Found {nfiles} images. Computing FID with {n} images.')
|
570 |
+
files = files[:n]
|
571 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
572 |
+
dims, device, num_workers, resize)
|
573 |
+
|
574 |
+
return m, s
|
575 |
+
|
576 |
+
|
577 |
+
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
|
578 |
+
"""Calculates the FID of two paths"""
|
579 |
+
for p in paths:
|
580 |
+
if not os.path.exists(p):
|
581 |
+
raise RuntimeError('Invalid path: %s' % p)
|
582 |
+
|
583 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
584 |
+
|
585 |
+
model = InceptionV3([block_idx]).to(device)
|
586 |
+
|
587 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
588 |
+
dims, device, num_workers, nimages, resize)
|
589 |
+
m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
|
590 |
+
dims, device, num_workers, nimages, resize)
|
591 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
592 |
+
|
593 |
+
return fid_value
|
594 |
+
|
595 |
+
|
596 |
+
def save_fid_stats(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
|
597 |
+
"""Calculates the FID of two paths"""
|
598 |
+
if not os.path.exists(paths[0]):
|
599 |
+
raise RuntimeError('Invalid path: %s' % paths[0])
|
600 |
+
|
601 |
+
if os.path.exists(paths[1]):
|
602 |
+
raise RuntimeError('Existing output file: %s' % paths[1])
|
603 |
+
|
604 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
605 |
+
|
606 |
+
model = InceptionV3([block_idx]).to(device)
|
607 |
+
|
608 |
+
print(f"Saving statistics for {paths[0]}")
|
609 |
+
|
610 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
611 |
+
dims, device, num_workers, nimages, resize=0)
|
612 |
+
|
613 |
+
np.savez_compressed(paths[1], mu=m1, sigma=s1)
|
614 |
+
|
615 |
+
|
616 |
+
def main():
|
617 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
618 |
+
parser.add_argument('--batch-size', type=int, default=20,
|
619 |
+
help='Batch size to use')
|
620 |
+
parser.add_argument('--num-workers', type=int,
|
621 |
+
help=('Number of processes to use for data loading. '
|
622 |
+
'Defaults to `min(8, num_cpus)`'))
|
623 |
+
parser.add_argument('--device', type=str, default='cuda:0',
|
624 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
625 |
+
parser.add_argument('--dims', type=int, default=2048,
|
626 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
627 |
+
help=('Dimensionality of Inception features to use. '
|
628 |
+
'By default, uses pool3 features'))
|
629 |
+
parser.add_argument('--nimages', type=int, default=50000, help='max number of images to use')
|
630 |
+
parser.add_argument('--resize', type=int, default=0, help='resize images to this size, 0 mean keep original size')
|
631 |
+
parser.add_argument('--save-stats', action='store_true',
|
632 |
+
help=('Generate an npz archive from a directory of samples. '
|
633 |
+
'The first path is used as input and the second as output.'))
|
634 |
+
parser.add_argument('path', type=str, nargs=2,
|
635 |
+
help=('Paths to the generated images or '
|
636 |
+
'to .npz statistic files'))
|
637 |
+
args = parser.parse_args()
|
638 |
+
|
639 |
+
if args.device is None:
|
640 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
641 |
+
else:
|
642 |
+
device = torch.device(args.device)
|
643 |
+
|
644 |
+
if args.num_workers is None:
|
645 |
+
try:
|
646 |
+
num_cpus = len(os.sched_getaffinity(0))
|
647 |
+
except AttributeError:
|
648 |
+
# os.sched_getaffinity is not available under Windows, use
|
649 |
+
# os.cpu_count instead (which may not return the *available* number
|
650 |
+
# of CPUs).
|
651 |
+
num_cpus = os.cpu_count()
|
652 |
+
|
653 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
654 |
+
else:
|
655 |
+
num_workers = args.num_workers
|
656 |
+
|
657 |
+
if args.save_stats:
|
658 |
+
save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers, args.nimages, args.resize)
|
659 |
+
return
|
660 |
+
|
661 |
+
fid_value = calculate_fid_given_paths(args.path,
|
662 |
+
args.batch_size,
|
663 |
+
device,
|
664 |
+
args.dims,
|
665 |
+
num_workers,
|
666 |
+
args.nimages,
|
667 |
+
args.resize)
|
668 |
+
print('FID: ', fid_value)
|
669 |
+
|
670 |
+
|
671 |
+
if __name__ == '__main__':
|
672 |
+
main()
|
tools/fid_lmdb.py
ADDED
@@ -0,0 +1,683 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Calculates the Frechet Inception Distance (FID) to evalulate GANs
|
2 |
+
|
3 |
+
The FID metric calculates the distance between two distributions of images.
|
4 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
5 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
6 |
+
|
7 |
+
When run as a stand-alone program, it compares the distribution of
|
8 |
+
images that are stored as PNG/JPEG at a specified location with a
|
9 |
+
distribution given by summary statistics (in pickle format).
|
10 |
+
|
11 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
12 |
+
the pool_3 layer of the inception net for generated samples and real world
|
13 |
+
samples respectively.
|
14 |
+
|
15 |
+
See --help to see further details.
|
16 |
+
|
17 |
+
Code apapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
18 |
+
of Tensorflow
|
19 |
+
|
20 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
21 |
+
|
22 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
23 |
+
you may not use this file except in compliance with the License.
|
24 |
+
You may obtain a copy of the License at
|
25 |
+
|
26 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
27 |
+
|
28 |
+
Unless required by applicable law or agreed to in writing, software
|
29 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
30 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
31 |
+
See the License for the specific language governing permissions and
|
32 |
+
limitations under the License.
|
33 |
+
"""
|
34 |
+
import os
|
35 |
+
import pathlib
|
36 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
37 |
+
|
38 |
+
import numpy as np
|
39 |
+
import torch
|
40 |
+
import torchvision.transforms as TF
|
41 |
+
from PIL import Image
|
42 |
+
from scipy import linalg
|
43 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
44 |
+
import torch.nn as nn
|
45 |
+
import torch.nn.functional as F
|
46 |
+
import torchvision
|
47 |
+
import sys
|
48 |
+
sys.path.insert(1, '/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet')
|
49 |
+
from tools.image_dataset import ImageDataset
|
50 |
+
try:
|
51 |
+
from tqdm import tqdm
|
52 |
+
except ImportError:
|
53 |
+
# If tqdm is not available, provide a mock version of it
|
54 |
+
def tqdm(x):
|
55 |
+
return x
|
56 |
+
|
57 |
+
|
58 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
59 |
+
'tif', 'tiff', 'webp'}
|
60 |
+
|
61 |
+
|
62 |
+
try:
|
63 |
+
from torchvision.models.utils import load_state_dict_from_url
|
64 |
+
except ImportError:
|
65 |
+
from torch.utils.model_zoo import load_url as load_state_dict_from_url
|
66 |
+
|
67 |
+
# Inception weights ported to Pytorch from
|
68 |
+
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
69 |
+
FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
|
70 |
+
|
71 |
+
|
72 |
+
class InceptionV3(nn.Module):
|
73 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
74 |
+
|
75 |
+
# Index of default block of inception to return,
|
76 |
+
# corresponds to output of final average pooling
|
77 |
+
DEFAULT_BLOCK_INDEX = 3
|
78 |
+
|
79 |
+
# Maps feature dimensionality to their output blocks indices
|
80 |
+
BLOCK_INDEX_BY_DIM = {
|
81 |
+
64: 0, # First max pooling features
|
82 |
+
192: 1, # Second max pooling featurs
|
83 |
+
768: 2, # Pre-aux classifier features
|
84 |
+
2048: 3 # Final average pooling features
|
85 |
+
}
|
86 |
+
|
87 |
+
def __init__(self,
|
88 |
+
output_blocks=(DEFAULT_BLOCK_INDEX,),
|
89 |
+
resize_input=True,
|
90 |
+
normalize_input=True,
|
91 |
+
requires_grad=False,
|
92 |
+
use_fid_inception=True):
|
93 |
+
"""Build pretrained InceptionV3
|
94 |
+
|
95 |
+
Parameters
|
96 |
+
----------
|
97 |
+
output_blocks : list of int
|
98 |
+
Indices of blocks to return features of. Possible values are:
|
99 |
+
- 0: corresponds to output of first max pooling
|
100 |
+
- 1: corresponds to output of second max pooling
|
101 |
+
- 2: corresponds to output which is fed to aux classifier
|
102 |
+
- 3: corresponds to output of final average pooling
|
103 |
+
resize_input : bool
|
104 |
+
If true, bilinearly resizes input to width and height 299 before
|
105 |
+
feeding input to model. As the network without fully connected
|
106 |
+
layers is fully convolutional, it should be able to handle inputs
|
107 |
+
of arbitrary size, so resizing might not be strictly needed
|
108 |
+
normalize_input : bool
|
109 |
+
If true, scales the input from range (0, 1) to the range the
|
110 |
+
pretrained Inception network expects, namely (-1, 1)
|
111 |
+
requires_grad : bool
|
112 |
+
If true, parameters of the model require gradients. Possibly useful
|
113 |
+
for finetuning the network
|
114 |
+
use_fid_inception : bool
|
115 |
+
If true, uses the pretrained Inception model used in Tensorflow's
|
116 |
+
FID implementation. If false, uses the pretrained Inception model
|
117 |
+
available in torchvision. The FID Inception model has different
|
118 |
+
weights and a slightly different structure from torchvision's
|
119 |
+
Inception model. If you want to compute FID scores, you are
|
120 |
+
strongly advised to set this parameter to true to get comparable
|
121 |
+
results.
|
122 |
+
"""
|
123 |
+
super(InceptionV3, self).__init__()
|
124 |
+
|
125 |
+
self.resize_input = resize_input
|
126 |
+
self.normalize_input = normalize_input
|
127 |
+
self.output_blocks = sorted(output_blocks)
|
128 |
+
self.last_needed_block = max(output_blocks)
|
129 |
+
|
130 |
+
assert self.last_needed_block <= 3, \
|
131 |
+
'Last possible output block index is 3'
|
132 |
+
|
133 |
+
self.blocks = nn.ModuleList()
|
134 |
+
|
135 |
+
if use_fid_inception:
|
136 |
+
inception = fid_inception_v3()
|
137 |
+
else:
|
138 |
+
inception = _inception_v3(weights='DEFAULT')
|
139 |
+
|
140 |
+
# Block 0: input to maxpool1
|
141 |
+
block0 = [
|
142 |
+
inception.Conv2d_1a_3x3,
|
143 |
+
inception.Conv2d_2a_3x3,
|
144 |
+
inception.Conv2d_2b_3x3,
|
145 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
146 |
+
]
|
147 |
+
self.blocks.append(nn.Sequential(*block0))
|
148 |
+
|
149 |
+
# Block 1: maxpool1 to maxpool2
|
150 |
+
if self.last_needed_block >= 1:
|
151 |
+
block1 = [
|
152 |
+
inception.Conv2d_3b_1x1,
|
153 |
+
inception.Conv2d_4a_3x3,
|
154 |
+
nn.MaxPool2d(kernel_size=3, stride=2)
|
155 |
+
]
|
156 |
+
self.blocks.append(nn.Sequential(*block1))
|
157 |
+
|
158 |
+
# Block 2: maxpool2 to aux classifier
|
159 |
+
if self.last_needed_block >= 2:
|
160 |
+
block2 = [
|
161 |
+
inception.Mixed_5b,
|
162 |
+
inception.Mixed_5c,
|
163 |
+
inception.Mixed_5d,
|
164 |
+
inception.Mixed_6a,
|
165 |
+
inception.Mixed_6b,
|
166 |
+
inception.Mixed_6c,
|
167 |
+
inception.Mixed_6d,
|
168 |
+
inception.Mixed_6e,
|
169 |
+
]
|
170 |
+
self.blocks.append(nn.Sequential(*block2))
|
171 |
+
|
172 |
+
# Block 3: aux classifier to final avgpool
|
173 |
+
if self.last_needed_block >= 3:
|
174 |
+
block3 = [
|
175 |
+
inception.Mixed_7a,
|
176 |
+
inception.Mixed_7b,
|
177 |
+
inception.Mixed_7c,
|
178 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
179 |
+
]
|
180 |
+
self.blocks.append(nn.Sequential(*block3))
|
181 |
+
|
182 |
+
for param in self.parameters():
|
183 |
+
param.requires_grad = requires_grad
|
184 |
+
|
185 |
+
def forward(self, inp):
|
186 |
+
"""Get Inception feature maps
|
187 |
+
|
188 |
+
Parameters
|
189 |
+
----------
|
190 |
+
inp : torch.autograd.Variable
|
191 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
192 |
+
range (0, 1)
|
193 |
+
|
194 |
+
Returns
|
195 |
+
-------
|
196 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
197 |
+
block, sorted ascending by index
|
198 |
+
"""
|
199 |
+
outp = []
|
200 |
+
x = inp
|
201 |
+
|
202 |
+
if self.resize_input:
|
203 |
+
x = F.interpolate(x,
|
204 |
+
size=(299, 299),
|
205 |
+
mode='bilinear',
|
206 |
+
align_corners=False)
|
207 |
+
|
208 |
+
if self.normalize_input:
|
209 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
210 |
+
|
211 |
+
for idx, block in enumerate(self.blocks):
|
212 |
+
x = block(x)
|
213 |
+
if idx in self.output_blocks:
|
214 |
+
outp.append(x)
|
215 |
+
|
216 |
+
if idx == self.last_needed_block:
|
217 |
+
break
|
218 |
+
|
219 |
+
return outp
|
220 |
+
|
221 |
+
|
222 |
+
def _inception_v3(*args, **kwargs):
|
223 |
+
"""Wraps `torchvision.models.inception_v3`"""
|
224 |
+
try:
|
225 |
+
version = tuple(map(int, torchvision.__version__.split('.')[:2]))
|
226 |
+
except ValueError:
|
227 |
+
# Just a caution against weird version strings
|
228 |
+
version = (0,)
|
229 |
+
|
230 |
+
# Skips default weight inititialization if supported by torchvision
|
231 |
+
# version. See https://github.com/mseitzer/pytorch-fid/issues/28.
|
232 |
+
if version >= (0, 6):
|
233 |
+
kwargs['init_weights'] = False
|
234 |
+
|
235 |
+
# Backwards compatibility: `weights` argument was handled by `pretrained`
|
236 |
+
# argument prior to version 0.13.
|
237 |
+
if version < (0, 13) and 'weights' in kwargs:
|
238 |
+
if kwargs['weights'] == 'DEFAULT':
|
239 |
+
kwargs['pretrained'] = True
|
240 |
+
elif kwargs['weights'] is None:
|
241 |
+
kwargs['pretrained'] = False
|
242 |
+
else:
|
243 |
+
raise ValueError(
|
244 |
+
'weights=={} not supported in torchvision {}'.format(
|
245 |
+
kwargs['weights'], torchvision.__version__
|
246 |
+
)
|
247 |
+
)
|
248 |
+
del kwargs['weights']
|
249 |
+
|
250 |
+
return torchvision.models.inception_v3(*args, **kwargs)
|
251 |
+
|
252 |
+
|
253 |
+
def fid_inception_v3():
|
254 |
+
"""Build pretrained Inception model for FID computation
|
255 |
+
|
256 |
+
The Inception model for FID computation uses a different set of weights
|
257 |
+
and has a slightly different structure than torchvision's Inception.
|
258 |
+
|
259 |
+
This method first constructs torchvision's Inception and then patches the
|
260 |
+
necessary parts that are different in the FID Inception model.
|
261 |
+
"""
|
262 |
+
inception = _inception_v3(num_classes=1008,
|
263 |
+
aux_logits=False,
|
264 |
+
weights=None)
|
265 |
+
inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
|
266 |
+
inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
|
267 |
+
inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
|
268 |
+
inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
|
269 |
+
inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
|
270 |
+
inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
|
271 |
+
inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
|
272 |
+
inception.Mixed_7b = FIDInceptionE_1(1280)
|
273 |
+
inception.Mixed_7c = FIDInceptionE_2(2048)
|
274 |
+
|
275 |
+
state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
|
276 |
+
inception.load_state_dict(state_dict)
|
277 |
+
return inception
|
278 |
+
|
279 |
+
|
280 |
+
class FIDInceptionA(torchvision.models.inception.InceptionA):
|
281 |
+
"""InceptionA block patched for FID computation"""
|
282 |
+
def __init__(self, in_channels, pool_features):
|
283 |
+
super(FIDInceptionA, self).__init__(in_channels, pool_features)
|
284 |
+
|
285 |
+
def forward(self, x):
|
286 |
+
branch1x1 = self.branch1x1(x)
|
287 |
+
|
288 |
+
branch5x5 = self.branch5x5_1(x)
|
289 |
+
branch5x5 = self.branch5x5_2(branch5x5)
|
290 |
+
|
291 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
292 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
293 |
+
branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
|
294 |
+
|
295 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
296 |
+
# its average calculation
|
297 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
298 |
+
count_include_pad=False)
|
299 |
+
branch_pool = self.branch_pool(branch_pool)
|
300 |
+
|
301 |
+
outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
|
302 |
+
return torch.cat(outputs, 1)
|
303 |
+
|
304 |
+
|
305 |
+
class FIDInceptionC(torchvision.models.inception.InceptionC):
|
306 |
+
"""InceptionC block patched for FID computation"""
|
307 |
+
def __init__(self, in_channels, channels_7x7):
|
308 |
+
super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
|
309 |
+
|
310 |
+
def forward(self, x):
|
311 |
+
branch1x1 = self.branch1x1(x)
|
312 |
+
|
313 |
+
branch7x7 = self.branch7x7_1(x)
|
314 |
+
branch7x7 = self.branch7x7_2(branch7x7)
|
315 |
+
branch7x7 = self.branch7x7_3(branch7x7)
|
316 |
+
|
317 |
+
branch7x7dbl = self.branch7x7dbl_1(x)
|
318 |
+
branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
|
319 |
+
branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
|
320 |
+
branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
|
321 |
+
branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
|
322 |
+
|
323 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
324 |
+
# its average calculation
|
325 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
326 |
+
count_include_pad=False)
|
327 |
+
branch_pool = self.branch_pool(branch_pool)
|
328 |
+
|
329 |
+
outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
|
330 |
+
return torch.cat(outputs, 1)
|
331 |
+
|
332 |
+
|
333 |
+
class FIDInceptionE_1(torchvision.models.inception.InceptionE):
|
334 |
+
"""First InceptionE block patched for FID computation"""
|
335 |
+
def __init__(self, in_channels):
|
336 |
+
super(FIDInceptionE_1, self).__init__(in_channels)
|
337 |
+
|
338 |
+
def forward(self, x):
|
339 |
+
branch1x1 = self.branch1x1(x)
|
340 |
+
|
341 |
+
branch3x3 = self.branch3x3_1(x)
|
342 |
+
branch3x3 = [
|
343 |
+
self.branch3x3_2a(branch3x3),
|
344 |
+
self.branch3x3_2b(branch3x3),
|
345 |
+
]
|
346 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
347 |
+
|
348 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
349 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
350 |
+
branch3x3dbl = [
|
351 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
352 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
353 |
+
]
|
354 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
355 |
+
|
356 |
+
# Patch: Tensorflow's average pool does not use the padded zero's in
|
357 |
+
# its average calculation
|
358 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
359 |
+
count_include_pad=False)
|
360 |
+
branch_pool = self.branch_pool(branch_pool)
|
361 |
+
|
362 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
363 |
+
return torch.cat(outputs, 1)
|
364 |
+
|
365 |
+
|
366 |
+
class FIDInceptionE_2(torchvision.models.inception.InceptionE):
|
367 |
+
"""Second InceptionE block patched for FID computation"""
|
368 |
+
def __init__(self, in_channels):
|
369 |
+
super(FIDInceptionE_2, self).__init__(in_channels)
|
370 |
+
|
371 |
+
def forward(self, x):
|
372 |
+
branch1x1 = self.branch1x1(x)
|
373 |
+
|
374 |
+
branch3x3 = self.branch3x3_1(x)
|
375 |
+
branch3x3 = [
|
376 |
+
self.branch3x3_2a(branch3x3),
|
377 |
+
self.branch3x3_2b(branch3x3),
|
378 |
+
]
|
379 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
380 |
+
|
381 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
382 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
383 |
+
branch3x3dbl = [
|
384 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
385 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
386 |
+
]
|
387 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
388 |
+
|
389 |
+
# Patch: The FID Inception model uses max pooling instead of average
|
390 |
+
# pooling. This is likely an error in this specific Inception
|
391 |
+
# implementation, as other Inception models use average pooling here
|
392 |
+
# (which matches the description in the paper).
|
393 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
394 |
+
branch_pool = self.branch_pool(branch_pool)
|
395 |
+
|
396 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
397 |
+
return torch.cat(outputs, 1)
|
398 |
+
|
399 |
+
class ImagePathDataset(torch.utils.data.Dataset):
|
400 |
+
def __init__(self, files, transforms=None):
|
401 |
+
self.files = files
|
402 |
+
self.transforms = transforms
|
403 |
+
|
404 |
+
def __len__(self):
|
405 |
+
return len(self.files)
|
406 |
+
|
407 |
+
def __getitem__(self, i):
|
408 |
+
path = self.files[i]
|
409 |
+
img = Image.open(path).convert('RGB')
|
410 |
+
if self.transforms is not None:
|
411 |
+
img = self.transforms(img)
|
412 |
+
return img
|
413 |
+
|
414 |
+
|
415 |
+
def get_activations(files, model, batch_size=50, dims=2048, device='cpu',
|
416 |
+
num_workers=1, resize=0):
|
417 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
418 |
+
|
419 |
+
Params:
|
420 |
+
-- files : List of image files paths
|
421 |
+
-- model : Instance of inception model
|
422 |
+
-- batch_size : Batch size of images for the model to process at once.
|
423 |
+
Make sure that the number of samples is a multiple of
|
424 |
+
the batch size, otherwise some samples are ignored. This
|
425 |
+
behavior is retained to match the original FID score
|
426 |
+
implementation.
|
427 |
+
-- dims : Dimensionality of features returned by Inception
|
428 |
+
-- device : Device to run calculations
|
429 |
+
-- num_workers : Number of parallel dataloader workers
|
430 |
+
|
431 |
+
Returns:
|
432 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
433 |
+
activations of the given tensor when feeding inception with the
|
434 |
+
query tensor.
|
435 |
+
"""
|
436 |
+
model.eval()
|
437 |
+
|
438 |
+
if batch_size > len(files):
|
439 |
+
print(('Warning: batch size is bigger than the data size. '
|
440 |
+
'Setting batch size to data size'))
|
441 |
+
batch_size = len(files)
|
442 |
+
if resize > 0:
|
443 |
+
tform = TF.Compose([TF.Resize((resize, resize)), TF.ToTensor()])
|
444 |
+
else:
|
445 |
+
tform = TF.ToTensor()
|
446 |
+
if isinstance(files, list):
|
447 |
+
dataset = ImagePathDataset(files, transforms=tform)
|
448 |
+
else:
|
449 |
+
files.set_transform(tform)
|
450 |
+
dataset = files
|
451 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
452 |
+
batch_size=batch_size,
|
453 |
+
shuffle=False,
|
454 |
+
drop_last=False,
|
455 |
+
num_workers=num_workers)
|
456 |
+
|
457 |
+
pred_arr = np.empty((len(files), dims))
|
458 |
+
|
459 |
+
start_idx = 0
|
460 |
+
|
461 |
+
for batch in tqdm(dataloader):
|
462 |
+
batch = batch['image'].to(device)
|
463 |
+
|
464 |
+
with torch.no_grad():
|
465 |
+
pred = model(batch)[0]
|
466 |
+
|
467 |
+
# If model output is not scalar, apply global spatial average pooling.
|
468 |
+
# This happens if you choose a dimensionality not equal 2048.
|
469 |
+
if pred.size(2) != 1 or pred.size(3) != 1:
|
470 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
471 |
+
|
472 |
+
pred = pred.squeeze(3).squeeze(2).cpu().numpy()
|
473 |
+
|
474 |
+
pred_arr[start_idx:start_idx + pred.shape[0]] = pred
|
475 |
+
|
476 |
+
start_idx = start_idx + pred.shape[0]
|
477 |
+
|
478 |
+
return pred_arr
|
479 |
+
|
480 |
+
|
481 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
482 |
+
"""Numpy implementation of the Frechet Distance.
|
483 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
484 |
+
and X_2 ~ N(mu_2, C_2) is
|
485 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
486 |
+
|
487 |
+
Stable version by Dougal J. Sutherland.
|
488 |
+
|
489 |
+
Params:
|
490 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
491 |
+
inception net (like returned by the function 'get_predictions')
|
492 |
+
for generated samples.
|
493 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
494 |
+
representative data set.
|
495 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
496 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
497 |
+
representative data set.
|
498 |
+
|
499 |
+
Returns:
|
500 |
+
-- : The Frechet Distance.
|
501 |
+
"""
|
502 |
+
|
503 |
+
mu1 = np.atleast_1d(mu1)
|
504 |
+
mu2 = np.atleast_1d(mu2)
|
505 |
+
|
506 |
+
sigma1 = np.atleast_2d(sigma1)
|
507 |
+
sigma2 = np.atleast_2d(sigma2)
|
508 |
+
|
509 |
+
assert mu1.shape == mu2.shape, \
|
510 |
+
'Training and test mean vectors have different lengths'
|
511 |
+
assert sigma1.shape == sigma2.shape, \
|
512 |
+
'Training and test covariances have different dimensions'
|
513 |
+
|
514 |
+
diff = mu1 - mu2
|
515 |
+
|
516 |
+
# Product might be almost singular
|
517 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
518 |
+
if not np.isfinite(covmean).all():
|
519 |
+
msg = ('fid calculation produces singular product; '
|
520 |
+
'adding %s to diagonal of cov estimates') % eps
|
521 |
+
print(msg)
|
522 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
523 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
524 |
+
|
525 |
+
# Numerical error might give slight imaginary component
|
526 |
+
if np.iscomplexobj(covmean):
|
527 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
528 |
+
m = np.max(np.abs(covmean.imag))
|
529 |
+
raise ValueError('Imaginary component {}'.format(m))
|
530 |
+
covmean = covmean.real
|
531 |
+
|
532 |
+
tr_covmean = np.trace(covmean)
|
533 |
+
|
534 |
+
return (diff.dot(diff) + np.trace(sigma1)
|
535 |
+
+ np.trace(sigma2) - 2 * tr_covmean)
|
536 |
+
|
537 |
+
|
538 |
+
def calculate_activation_statistics(files, model, batch_size=50, dims=2048,
|
539 |
+
device='cpu', num_workers=1, resize=0):
|
540 |
+
"""Calculation of the statistics used by the FID.
|
541 |
+
Params:
|
542 |
+
-- files : List of image files paths
|
543 |
+
-- model : Instance of inception model
|
544 |
+
-- batch_size : The images numpy array is split into batches with
|
545 |
+
batch size batch_size. A reasonable batch size
|
546 |
+
depends on the hardware.
|
547 |
+
-- dims : Dimensionality of features returned by Inception
|
548 |
+
-- device : Device to run calculations
|
549 |
+
-- num_workers : Number of parallel dataloader workers
|
550 |
+
|
551 |
+
Returns:
|
552 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
553 |
+
the inception model.
|
554 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
555 |
+
the inception model.
|
556 |
+
"""
|
557 |
+
act = get_activations(files, model, batch_size, dims, device, num_workers, resize)
|
558 |
+
mu = np.mean(act, axis=0)
|
559 |
+
sigma = np.cov(act, rowvar=False)
|
560 |
+
return mu, sigma
|
561 |
+
|
562 |
+
|
563 |
+
def compute_statistics_of_path(path, model, batch_size, dims, device,
|
564 |
+
num_workers=1, nimages=None, resize=0):
|
565 |
+
if path.endswith('.npz'):
|
566 |
+
with np.load(path) as f:
|
567 |
+
m, s = f['mu'][:], f['sigma'][:]
|
568 |
+
else:
|
569 |
+
path = pathlib.Path(path)
|
570 |
+
if (path/'data.mdb').exists():
|
571 |
+
files = ImageDataset(path, None)
|
572 |
+
nfiles = len(files)
|
573 |
+
n = nfiles if nimages is None else min(nimages, nfiles)
|
574 |
+
files.set_ids(range(n))
|
575 |
+
else:
|
576 |
+
files = sorted([file for ext in IMAGE_EXTENSIONS
|
577 |
+
for file in path.glob('**/*.{}'.format(ext))])
|
578 |
+
nfiles = len(files)
|
579 |
+
n = nfiles if nimages is None else min(nimages, nfiles)
|
580 |
+
files = files[:n]
|
581 |
+
print(f'Found {nfiles} images. Computing FID with {n} images.')
|
582 |
+
m, s = calculate_activation_statistics(files, model, batch_size,
|
583 |
+
dims, device, num_workers, resize)
|
584 |
+
|
585 |
+
return m, s
|
586 |
+
|
587 |
+
|
588 |
+
def calculate_fid_given_paths(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
|
589 |
+
"""Calculates the FID of two paths"""
|
590 |
+
for p in paths:
|
591 |
+
if not os.path.exists(p):
|
592 |
+
raise RuntimeError('Invalid path: %s' % p)
|
593 |
+
|
594 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
595 |
+
|
596 |
+
model = InceptionV3([block_idx]).to(device)
|
597 |
+
|
598 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
599 |
+
dims, device, num_workers, nimages, resize)
|
600 |
+
m2, s2 = compute_statistics_of_path(paths[1], model, batch_size,
|
601 |
+
dims, device, num_workers, nimages, resize)
|
602 |
+
fid_value = calculate_frechet_distance(m1, s1, m2, s2)
|
603 |
+
|
604 |
+
return fid_value
|
605 |
+
|
606 |
+
|
607 |
+
def save_fid_stats(paths, batch_size, device, dims, num_workers=1, nimages=None, resize=0):
|
608 |
+
"""Calculates the FID of two paths"""
|
609 |
+
if not os.path.exists(paths[0]):
|
610 |
+
raise RuntimeError('Invalid path: %s' % paths[0])
|
611 |
+
|
612 |
+
if os.path.exists(paths[1]):
|
613 |
+
raise RuntimeError('Existing output file: %s' % paths[1])
|
614 |
+
|
615 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
616 |
+
|
617 |
+
model = InceptionV3([block_idx]).to(device)
|
618 |
+
|
619 |
+
print(f"Saving statistics for {paths[0]}")
|
620 |
+
|
621 |
+
m1, s1 = compute_statistics_of_path(paths[0], model, batch_size,
|
622 |
+
dims, device, num_workers, nimages, resize=0)
|
623 |
+
|
624 |
+
np.savez_compressed(paths[1], mu=m1, sigma=s1)
|
625 |
+
|
626 |
+
|
627 |
+
def main():
|
628 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
629 |
+
parser.add_argument('--batch-size', type=int, default=20,
|
630 |
+
help='Batch size to use')
|
631 |
+
parser.add_argument('--num-workers', type=int,
|
632 |
+
help=('Number of processes to use for data loading. '
|
633 |
+
'Defaults to `min(8, num_cpus)`'))
|
634 |
+
parser.add_argument('--device', type=str, default='cuda:0',
|
635 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
636 |
+
parser.add_argument('--dims', type=int, default=2048,
|
637 |
+
choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
|
638 |
+
help=('Dimensionality of Inception features to use. '
|
639 |
+
'By default, uses pool3 features'))
|
640 |
+
parser.add_argument('--nimages', type=int, default=50000, help='max number of images to use')
|
641 |
+
parser.add_argument('--resize', type=int, default=0, help='resize images to this size, 0 mean keep original size')
|
642 |
+
parser.add_argument('--save-stats', action='store_true',
|
643 |
+
help=('Generate an npz archive from a directory of samples. '
|
644 |
+
'The first path is used as input and the second as output.'))
|
645 |
+
parser.add_argument('path', type=str, nargs=2,
|
646 |
+
help=('Paths to the generated images or '
|
647 |
+
'to .npz statistic files'))
|
648 |
+
args = parser.parse_args()
|
649 |
+
|
650 |
+
if args.device is None:
|
651 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
652 |
+
else:
|
653 |
+
device = torch.device(args.device)
|
654 |
+
|
655 |
+
if args.num_workers is None:
|
656 |
+
try:
|
657 |
+
num_cpus = len(os.sched_getaffinity(0))
|
658 |
+
except AttributeError:
|
659 |
+
# os.sched_getaffinity is not available under Windows, use
|
660 |
+
# os.cpu_count instead (which may not return the *available* number
|
661 |
+
# of CPUs).
|
662 |
+
num_cpus = os.cpu_count()
|
663 |
+
|
664 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
665 |
+
else:
|
666 |
+
num_workers = args.num_workers
|
667 |
+
|
668 |
+
if args.save_stats:
|
669 |
+
save_fid_stats(args.path, args.batch_size, device, args.dims, num_workers, args.nimages, args.resize)
|
670 |
+
return
|
671 |
+
|
672 |
+
fid_value = calculate_fid_given_paths(args.path,
|
673 |
+
args.batch_size,
|
674 |
+
device,
|
675 |
+
args.dims,
|
676 |
+
num_workers,
|
677 |
+
args.nimages,
|
678 |
+
args.resize)
|
679 |
+
print('FID: ', fid_value)
|
680 |
+
|
681 |
+
|
682 |
+
if __name__ == '__main__':
|
683 |
+
main()
|
tools/gradcam.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
gradcam visualisation for each GAN class
|
5 |
+
@author: Tu Bui @surrey.ac.uk
|
6 |
+
"""
|
7 |
+
from __future__ import absolute_import
|
8 |
+
from __future__ import division
|
9 |
+
from __future__ import print_function
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
import inspect
|
13 |
+
import argparse
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
import matplotlib
|
17 |
+
matplotlib.use('Agg')
|
18 |
+
import matplotlib.pyplot as plt
|
19 |
+
import cv2
|
20 |
+
from PIL import Image, ImageDraw, ImageFont
|
21 |
+
import torch
|
22 |
+
import torchvision
|
23 |
+
from torch.autograd import Function
|
24 |
+
import torch.nn.functional as F
|
25 |
+
|
26 |
+
|
27 |
+
def show_cam_on_image(img, cam, cmap='jet'):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
img PIL image (H,W,3)
|
31 |
+
cam heatmap (H, W), range [0,1]
|
32 |
+
Returns:
|
33 |
+
PIL image with heatmap applied.
|
34 |
+
"""
|
35 |
+
cm = plt.get_cmap(cmap)
|
36 |
+
cam = cm(cam)[...,:3] # RGB [0,1]
|
37 |
+
cam = np.array(img, dtype=np.float32)/255. + cam
|
38 |
+
cam /= cam.max()
|
39 |
+
cam = np.uint8(cam*255)
|
40 |
+
return Image.fromarray(cam)
|
41 |
+
|
42 |
+
|
43 |
+
class HookedModel(object):
|
44 |
+
def __init__(self, model, feature_layer_name):
|
45 |
+
self.model = model
|
46 |
+
self.feature_trees = feature_layer_name.split('.')
|
47 |
+
|
48 |
+
def __call__(self, x):
|
49 |
+
x = feedforward(x, self.model, self.feature_trees)
|
50 |
+
return x
|
51 |
+
|
52 |
+
|
53 |
+
def feedforward(x, module, layer_names):
|
54 |
+
for name, submodule in module._modules.items():
|
55 |
+
# print(f'Forwarding {name} ...')
|
56 |
+
if name == layer_names[0]:
|
57 |
+
if len(layer_names) == 1: # leaf node reached
|
58 |
+
# print(f' Hook {name}')
|
59 |
+
x = submodule(x)
|
60 |
+
x.register_hook(save_gradients)
|
61 |
+
save_features(x)
|
62 |
+
else:
|
63 |
+
# print(f' Stepping into {name}:')
|
64 |
+
x = feedforward(x, submodule, layer_names[1:])
|
65 |
+
else:
|
66 |
+
x = submodule(x)
|
67 |
+
if name == 'avgpool': # specific for resnet50
|
68 |
+
x = x.view(x.size(0), -1)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
basket = dict(grads=[], feature_maps=[]) # global variable to hold the gradients and output features of the layers of interest
|
73 |
+
|
74 |
+
def empty_basket():
|
75 |
+
basket = dict(grads=[], feature_maps=[])
|
76 |
+
|
77 |
+
def save_gradients(grad):
|
78 |
+
basket['grads'].append(grad)
|
79 |
+
|
80 |
+
def save_features(feat):
|
81 |
+
basket['feature_maps'].append(feat)
|
82 |
+
|
83 |
+
|
84 |
+
class GradCam(object):
|
85 |
+
def __init__(self, model, feature_layer_name, use_cuda=True):
|
86 |
+
self.model = model
|
87 |
+
self.hooked_model = HookedModel(model, feature_layer_name)
|
88 |
+
self.cuda = use_cuda
|
89 |
+
if self.cuda:
|
90 |
+
self.model = model.cuda()
|
91 |
+
self.model.eval()
|
92 |
+
|
93 |
+
def __call__(self, x, target, act=None):
|
94 |
+
empty_basket()
|
95 |
+
target = torch.as_tensor(target, dtype=torch.float)
|
96 |
+
if self.cuda:
|
97 |
+
x = x.cuda()
|
98 |
+
target = target.cuda()
|
99 |
+
z = self.hooked_model(x)
|
100 |
+
if act is not None:
|
101 |
+
z = act(z)
|
102 |
+
criteria = F.cosine_similarity(z, target)
|
103 |
+
self.model.zero_grad()
|
104 |
+
criteria.backward(retain_graph=True)
|
105 |
+
gradients = [grad.cpu().data.numpy() for grad in basket['grads'][::-1]] # gradients appear in reversed order
|
106 |
+
feature_maps = [feat.cpu().data.numpy() for feat in basket['feature_maps']]
|
107 |
+
cams = []
|
108 |
+
for feat, grad in zip(feature_maps, gradients):
|
109 |
+
# feat and grad have shape (1, C, H, W)
|
110 |
+
weight = np.mean(grad, axis=(2,3), keepdims=True)[0] # (C,1,1)
|
111 |
+
cam = np.sum(weight * feat[0], axis=0) # (H,w)
|
112 |
+
cam = cv2.resize(cam, x.shape[2:])
|
113 |
+
cam = cam - np.min(cam)
|
114 |
+
cam = cam / (np.max(cam) + np.finfo(np.float32).eps)
|
115 |
+
cams.append(cam)
|
116 |
+
cams = np.array(cams).mean(axis=0) # (H,W)
|
117 |
+
return cams
|
118 |
+
|
119 |
+
|
120 |
+
def gradcam_demo():
|
121 |
+
from torchvision import transforms
|
122 |
+
model = torchvision.models.resnet50(pretrained=True)
|
123 |
+
model.eval()
|
124 |
+
gradcam = GradCam(model, 'layer4.2', True)
|
125 |
+
tform = [
|
126 |
+
transforms.Resize((224, 224)),
|
127 |
+
# transforms.CenterCrop(224),
|
128 |
+
transforms.ToTensor(),
|
129 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
130 |
+
]
|
131 |
+
preprocess = transforms.Compose(tform)
|
132 |
+
im0 = Image.open('/mnt/fast/nobackup/users/tb0035/projects/diffsteg/ControlNet/examples/catdog.jpg').convert('RGB')
|
133 |
+
im = preprocess(im0).unsqueeze(0)
|
134 |
+
target = np.zeros((1,1000), dtype=np.float32)
|
135 |
+
target[0, 285] = 1 # cat
|
136 |
+
cam = gradcam(im, target)
|
137 |
+
|
138 |
+
im0 = tform[0](im0)
|
139 |
+
out = show_cam_on_image(im0, cam)
|
140 |
+
out.save('test.jpg')
|
141 |
+
print('done')
|
142 |
+
|
143 |
+
|
144 |
+
def make_target_vector(nclass, target_class_id):
|
145 |
+
out = np.zeros((1, nclass), dtype=np.float32)
|
146 |
+
out[0, target_class_id] = 1
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
|
151 |
+
if __name__ == '__main__':
|
152 |
+
gradcam_demo()
|
tools/helpers.py
ADDED
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
"""
|
3 |
+
Created on Tue Jul 12 11:05:57 2016
|
4 |
+
some help functions to perform basic tasks
|
5 |
+
@author: tb00083
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
import csv
|
10 |
+
import socket
|
11 |
+
import numpy as np
|
12 |
+
import json
|
13 |
+
import pickle # python3.x
|
14 |
+
import time
|
15 |
+
from datetime import timedelta, datetime
|
16 |
+
from typing import Any, List, Tuple, Union
|
17 |
+
import subprocess
|
18 |
+
import struct
|
19 |
+
import errno
|
20 |
+
from pprint import pprint
|
21 |
+
import glob
|
22 |
+
from threading import Thread
|
23 |
+
|
24 |
+
|
25 |
+
def welcome_message():
|
26 |
+
"""
|
27 |
+
get welcome message including hostname and command line arguments
|
28 |
+
"""
|
29 |
+
hostname = socket.gethostname()
|
30 |
+
all_args = ' '.join(sys.argv)
|
31 |
+
out_text = 'On server {}: {}\n'.format(hostname, all_args)
|
32 |
+
return out_text
|
33 |
+
|
34 |
+
|
35 |
+
class EasyDict(dict):
|
36 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
37 |
+
def __init__(self, dict_to_convert=None):
|
38 |
+
if dict_to_convert is not None:
|
39 |
+
for key, val in dict_to_convert.items():
|
40 |
+
self[key] = val
|
41 |
+
|
42 |
+
def __getattr__(self, name: str) -> Any:
|
43 |
+
try:
|
44 |
+
return self[name]
|
45 |
+
except KeyError:
|
46 |
+
raise AttributeError(name)
|
47 |
+
|
48 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
49 |
+
self[name] = value
|
50 |
+
|
51 |
+
def __delattr__(self, name: str) -> None:
|
52 |
+
del self[name]
|
53 |
+
|
54 |
+
|
55 |
+
def get_time_id_str():
|
56 |
+
"""
|
57 |
+
returns a string with DDHHM format, where M is the minutes cut to the tenths
|
58 |
+
"""
|
59 |
+
now = datetime.now()
|
60 |
+
time_str = "{:02d}{:02d}{:02d}".format(now.day, now.hour, now.minute)
|
61 |
+
time_str = time_str[:-1]
|
62 |
+
return time_str
|
63 |
+
|
64 |
+
|
65 |
+
def time_format(t):
|
66 |
+
m, s = divmod(t, 60)
|
67 |
+
h, m = divmod(m, 60)
|
68 |
+
m, h, s = int(m), int(h), int(s)
|
69 |
+
|
70 |
+
if m == 0 and h == 0:
|
71 |
+
return "{}s".format(s)
|
72 |
+
elif h == 0:
|
73 |
+
return "{}m{}s".format(m, s)
|
74 |
+
else:
|
75 |
+
return "{}h{}m{}s".format(h, m, s)
|
76 |
+
|
77 |
+
|
78 |
+
def get_all_files(dir_path, trim=0, extension=''):
|
79 |
+
"""
|
80 |
+
Recursively get list of all files in the given directory
|
81 |
+
trim = 1 : trim the dir_path from results, 0 otherwise
|
82 |
+
extension: get files with specific format
|
83 |
+
"""
|
84 |
+
file_paths = [] # List which will store all of the full filepaths.
|
85 |
+
|
86 |
+
# Walk the tree.
|
87 |
+
for root, directories, files in os.walk(dir_path):
|
88 |
+
for filename in files:
|
89 |
+
# Join the two strings in order to form the full filepath.
|
90 |
+
filepath = os.path.join(root, filename)
|
91 |
+
file_paths.append(filepath) # Add it to the list.
|
92 |
+
|
93 |
+
if trim == 1: # trim dir_path from results
|
94 |
+
if dir_path[-1] != os.sep:
|
95 |
+
dir_path += os.sep
|
96 |
+
trim_len = len(dir_path)
|
97 |
+
file_paths = [x[trim_len:] for x in file_paths]
|
98 |
+
|
99 |
+
if extension: # select only file with specific extension
|
100 |
+
extension = extension.lower()
|
101 |
+
tlen = len(extension)
|
102 |
+
file_paths = [x for x in file_paths if x[-tlen:] == extension]
|
103 |
+
|
104 |
+
return file_paths # Self-explanatory.
|
105 |
+
|
106 |
+
|
107 |
+
def get_all_dirs(dir_path, trim=0):
|
108 |
+
"""
|
109 |
+
Recursively get list of all directories in the given directory
|
110 |
+
excluding the '.' and '..' directories
|
111 |
+
trim = 1 : trim the dir_path from results, 0 otherwise
|
112 |
+
"""
|
113 |
+
out = []
|
114 |
+
# Walk the tree.
|
115 |
+
for root, directories, files in os.walk(dir_path):
|
116 |
+
for dirname in directories:
|
117 |
+
# Join the two strings in order to form the full filepath.
|
118 |
+
dir_full = os.path.join(root, dirname)
|
119 |
+
out.append(dir_full) # Add it to the list.
|
120 |
+
|
121 |
+
if trim == 1: # trim dir_path from results
|
122 |
+
if dir_path[-1] != os.sep:
|
123 |
+
dir_path += os.sep
|
124 |
+
trim_len = len(dir_path)
|
125 |
+
out = [x[trim_len:] for x in out]
|
126 |
+
|
127 |
+
return out
|
128 |
+
|
129 |
+
|
130 |
+
def read_list(file_path, delimeter=' ', keep_original=True):
|
131 |
+
"""
|
132 |
+
read list column wise
|
133 |
+
deprecated, should use pandas instead
|
134 |
+
"""
|
135 |
+
out = []
|
136 |
+
with open(file_path, 'r') as f:
|
137 |
+
reader = csv.reader(f, delimiter=delimeter)
|
138 |
+
for row in reader:
|
139 |
+
out.append(row)
|
140 |
+
out = zip(*out)
|
141 |
+
|
142 |
+
if not keep_original:
|
143 |
+
for col in range(len(out)):
|
144 |
+
if out[col][0].isdigit(): # attempt to convert to numerical array
|
145 |
+
out[col] = np.array(out[col]).astype(np.int64)
|
146 |
+
|
147 |
+
return out
|
148 |
+
|
149 |
+
|
150 |
+
def save_pickle2(file_path, **kwargs):
|
151 |
+
"""
|
152 |
+
save variables to file (using pickle)
|
153 |
+
"""
|
154 |
+
# check if any variable is a dict
|
155 |
+
var_count = 0
|
156 |
+
for key in kwargs:
|
157 |
+
var_count += 1
|
158 |
+
if isinstance(kwargs[key], dict):
|
159 |
+
sys.stderr.write('Opps! Cannot write a dictionary into pickle')
|
160 |
+
sys.exit(1)
|
161 |
+
with open(file_path, 'wb') as f:
|
162 |
+
pickler = pickle.Pickler(f, -1)
|
163 |
+
pickler.dump(var_count)
|
164 |
+
for key in kwargs:
|
165 |
+
pickler.dump(key)
|
166 |
+
pickler.dump(kwargs[key])
|
167 |
+
|
168 |
+
|
169 |
+
def load_pickle2(file_path, varnum=0):
|
170 |
+
"""
|
171 |
+
load variables that previously saved using self.save()
|
172 |
+
varnum : number of variables u want to load (0 mean it will load all)
|
173 |
+
Note: if you are loading class instance(s), you must have it defined in advance
|
174 |
+
"""
|
175 |
+
with open(file_path, 'rb') as f:
|
176 |
+
pickler = pickle.Unpickler(f)
|
177 |
+
var_count = pickler.load()
|
178 |
+
if varnum:
|
179 |
+
var_count = min([var_count, varnum])
|
180 |
+
out = {}
|
181 |
+
for i in range(var_count):
|
182 |
+
key = pickler.load()
|
183 |
+
out[key] = pickler.load()
|
184 |
+
|
185 |
+
return out
|
186 |
+
|
187 |
+
|
188 |
+
def save_pickle(path, obj):
|
189 |
+
"""
|
190 |
+
simple method to save a picklable object
|
191 |
+
:param path: path to save
|
192 |
+
:param obj: a picklable object
|
193 |
+
:return: None
|
194 |
+
"""
|
195 |
+
with open(path, 'wb') as f:
|
196 |
+
pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)
|
197 |
+
|
198 |
+
|
199 |
+
def load_pickle(path):
|
200 |
+
"""
|
201 |
+
load a pickled object
|
202 |
+
:param path: .pkl path
|
203 |
+
:return: the pickled object
|
204 |
+
"""
|
205 |
+
with open(path, 'rb') as f:
|
206 |
+
return pickle.load(f)
|
207 |
+
|
208 |
+
|
209 |
+
def make_new_dir(dir_path, remove_existing=False, mode=511):
|
210 |
+
"""note: default mode in ubuntu is 511"""
|
211 |
+
if not os.path.exists(dir_path):
|
212 |
+
try:
|
213 |
+
if mode == 777:
|
214 |
+
oldmask = os.umask(000)
|
215 |
+
os.makedirs(dir_path, 0o777)
|
216 |
+
os.umask(oldmask)
|
217 |
+
else:
|
218 |
+
os.makedirs(dir_path, mode)
|
219 |
+
except OSError as exc: # Python >2.5
|
220 |
+
if exc.errno == errno.EEXIST and os.path.isdir(dir_path):
|
221 |
+
pass
|
222 |
+
else:
|
223 |
+
raise
|
224 |
+
if remove_existing:
|
225 |
+
for file_obj in os.listdir(dir_path):
|
226 |
+
file_path = os.path.join(dir_path, file_obj)
|
227 |
+
if os.path.isfile(file_path):
|
228 |
+
os.unlink(file_path)
|
229 |
+
|
230 |
+
|
231 |
+
def get_latest_file(root, pattern):
|
232 |
+
"""
|
233 |
+
get the latest file in a directory that match the provided pattern
|
234 |
+
useful for getting the last checkpoint
|
235 |
+
:param root: search directory
|
236 |
+
:param pattern: search pattern containing 1 wild card representing a number e.g. 'ckpt_*.tar'
|
237 |
+
:return: full path of the file with largest number in wild card, None if not found
|
238 |
+
"""
|
239 |
+
out = None
|
240 |
+
parts = pattern.split('*')
|
241 |
+
max_id = - np.inf
|
242 |
+
for path in glob.glob(os.path.join(root, pattern)):
|
243 |
+
id_ = os.path.basename(path)
|
244 |
+
for part in parts:
|
245 |
+
id_ = id_.replace(part, '')
|
246 |
+
try:
|
247 |
+
id_ = int(id_)
|
248 |
+
if id_ > max_id:
|
249 |
+
max_id = id_
|
250 |
+
out = path
|
251 |
+
except:
|
252 |
+
continue
|
253 |
+
return out
|
254 |
+
|
255 |
+
|
256 |
+
class Locker(object):
|
257 |
+
"""place a lock file in specified location
|
258 |
+
useful for distributed computing"""
|
259 |
+
|
260 |
+
def __init__(self, name='lock.txt', mode=511):
|
261 |
+
"""INPUT: name default file name to be created as a lock
|
262 |
+
mode if a directory has to be created, set its permission to mode"""
|
263 |
+
self.name = name
|
264 |
+
self.mode = mode
|
265 |
+
|
266 |
+
def lock(self, path):
|
267 |
+
make_new_dir(path, False, self.mode)
|
268 |
+
with open(os.path.join(path, self.name), 'w') as f:
|
269 |
+
f.write('progress')
|
270 |
+
|
271 |
+
def finish(self, path):
|
272 |
+
make_new_dir(path, False, self.mode)
|
273 |
+
with open(os.path.join(path, self.name), 'w') as f:
|
274 |
+
f.write('finish')
|
275 |
+
|
276 |
+
def customise(self, path, text):
|
277 |
+
make_new_dir(path, False, self.mode)
|
278 |
+
with open(os.path.join(path, self.name), 'w') as f:
|
279 |
+
f.write(text)
|
280 |
+
|
281 |
+
def is_locked(self, path):
|
282 |
+
out = False
|
283 |
+
check_path = os.path.join(path, self.name)
|
284 |
+
if os.path.exists(check_path):
|
285 |
+
text = open(check_path, 'r').readline().strip()
|
286 |
+
out = True if text == 'progress' else False
|
287 |
+
return out
|
288 |
+
|
289 |
+
def is_finished(self, path):
|
290 |
+
out = False
|
291 |
+
check_path = os.path.join(path, self.name)
|
292 |
+
if os.path.exists(check_path):
|
293 |
+
text = open(check_path, 'r').readline().strip()
|
294 |
+
out = True if text == 'finish' else False
|
295 |
+
return out
|
296 |
+
|
297 |
+
def is_locked_or_finished(self, path):
|
298 |
+
return self.is_locked(path) | self.is_finished(path)
|
299 |
+
|
300 |
+
def clean(self, path):
|
301 |
+
check_path = os.path.join(path, self.name)
|
302 |
+
if os.path.exists(check_path):
|
303 |
+
try:
|
304 |
+
os.remove(check_path)
|
305 |
+
except Exception as e:
|
306 |
+
print('Unable to remove %s: %s.' % (check_path, e))
|
307 |
+
|
308 |
+
|
309 |
+
class ProgressBar(object):
|
310 |
+
"""show progress"""
|
311 |
+
|
312 |
+
def __init__(self, total, increment=5):
|
313 |
+
self.total = total
|
314 |
+
self.point = self.total / 100.0
|
315 |
+
self.increment = increment
|
316 |
+
self.interval = int(self.total * self.increment / 100)
|
317 |
+
self.milestones = list(range(0, total, self.interval)) + [self.total, ]
|
318 |
+
self.id = 0
|
319 |
+
|
320 |
+
def show_progress(self, i):
|
321 |
+
if i >= self.milestones[self.id]:
|
322 |
+
while i >= self.milestones[self.id]:
|
323 |
+
self.id += 1
|
324 |
+
sys.stdout.write("\r[" + "=" * int(i / self.interval) +
|
325 |
+
" " * int((self.total - i) / self.interval) + "]" + str(int((i + 1) / self.point)) + "%")
|
326 |
+
sys.stdout.flush()
|
327 |
+
|
328 |
+
|
329 |
+
class Timer(object):
|
330 |
+
|
331 |
+
def __init__(self):
|
332 |
+
self.start_t = time.time()
|
333 |
+
self.last_t = self.start_t
|
334 |
+
|
335 |
+
def time(self, lap=False):
|
336 |
+
end_t = time.time()
|
337 |
+
if lap:
|
338 |
+
out = timedelta(seconds=int(end_t - self.last_t)) # count from last stop point
|
339 |
+
else:
|
340 |
+
out = timedelta(seconds=int(end_t - self.start_t)) # count from beginning
|
341 |
+
self.last_t = end_t
|
342 |
+
return out
|
343 |
+
|
344 |
+
|
345 |
+
class ExThread(Thread):
|
346 |
+
def run(self):
|
347 |
+
self.exc = None
|
348 |
+
try:
|
349 |
+
if hasattr(self, '_Thread__target'):
|
350 |
+
# Thread uses name mangling prior to Python 3.
|
351 |
+
self.ret = self._Thread__target(*self._Thread__args, **self._Thread__kwargs)
|
352 |
+
else:
|
353 |
+
self.ret = self._target(*self._args, **self._kwargs)
|
354 |
+
except BaseException as e:
|
355 |
+
self.exc = e
|
356 |
+
|
357 |
+
def join(self):
|
358 |
+
super(ExThread, self).join()
|
359 |
+
if self.exc:
|
360 |
+
raise RuntimeError('Exception in thread.') from self.exc
|
361 |
+
return self.ret
|
362 |
+
|
363 |
+
|
364 |
+
def get_gpu_free_mem():
|
365 |
+
"""return a list of free GPU memory"""
|
366 |
+
sp = subprocess.Popen(['nvidia-smi', '-q'], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
|
367 |
+
out_str = sp.communicate()
|
368 |
+
out_list = out_str[0].decode("utf-8") .split('\n')
|
369 |
+
|
370 |
+
out = []
|
371 |
+
for i in range(len(out_list)):
|
372 |
+
item = out_list[i]
|
373 |
+
if item.strip() == 'FB Memory Usage':
|
374 |
+
free_mem = int(out_list[i + 3].split(':')[1].strip().split(' ')[0])
|
375 |
+
out.append(free_mem)
|
376 |
+
return out
|
377 |
+
|
378 |
+
|
379 |
+
def float2hex(x):
|
380 |
+
"""
|
381 |
+
x: a vector
|
382 |
+
return: x in hex
|
383 |
+
"""
|
384 |
+
f = np.float32(x)
|
385 |
+
out = ''
|
386 |
+
if f.size == 1: # just a single number
|
387 |
+
f = [f, ]
|
388 |
+
for e in f:
|
389 |
+
h = hex(struct.unpack('<I', struct.pack('<f', e))[0])
|
390 |
+
out += h[2:].zfill(8)
|
391 |
+
return out
|
392 |
+
|
393 |
+
|
394 |
+
def hex2float(x):
|
395 |
+
"""
|
396 |
+
x: a string with len divided by 8
|
397 |
+
return x as array of float32
|
398 |
+
"""
|
399 |
+
assert len(x) % 8 == 0, 'Error! string len = {} not divided by 8'.format(len(x))
|
400 |
+
l = len(x) / 8
|
401 |
+
out = np.empty(l, dtype=np.float32)
|
402 |
+
x = [x[i:i + 8] for i in range(0, len(x), 8)]
|
403 |
+
for i, e in enumerate(x):
|
404 |
+
out[i] = struct.unpack('!f', e.decode('hex'))[0]
|
405 |
+
return out
|
406 |
+
|
407 |
+
|
408 |
+
def nice_print(inputs, stream=sys.stdout):
|
409 |
+
"""print a list of string to file stream"""
|
410 |
+
if type(inputs) is not list:
|
411 |
+
tstrings = inputs.split('\n')
|
412 |
+
pprint(tstrings, stream=stream)
|
413 |
+
else:
|
414 |
+
for string in inputs:
|
415 |
+
nice_print(string, stream=stream)
|
416 |
+
stream.flush()
|
tools/hparams.py
ADDED
@@ -0,0 +1,743 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2019 The Tensor2Tensor Authors.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# source: https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/hparam.py
|
16 |
+
# Forked with minor changes from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long
|
17 |
+
"""Hyperparameter values."""
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import division
|
20 |
+
from __future__ import print_function
|
21 |
+
|
22 |
+
import json
|
23 |
+
import numbers
|
24 |
+
import re
|
25 |
+
import six
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
# Define the regular expression for parsing a single clause of the input
|
29 |
+
# (delimited by commas). A legal clause looks like:
|
30 |
+
# <variable name>[<index>]? = <rhs>
|
31 |
+
# where <rhs> is either a single token or [] enclosed list of tokens.
|
32 |
+
# For example: "var[1] = a" or "x = [1,2,3]"
|
33 |
+
PARAM_RE = re.compile(r"""
|
34 |
+
(?P<name>[a-zA-Z][\w\.]*) # variable name: "var" or "x"
|
35 |
+
(\[\s*(?P<index>\d+)\s*\])? # (optional) index: "1" or None
|
36 |
+
\s*=\s*
|
37 |
+
((?P<val>[^,\[]*) # single value: "a" or None
|
38 |
+
|
|
39 |
+
\[(?P<vals>[^\]]*)\]) # list of values: None or "1,2,3"
|
40 |
+
($|,\s*)""", re.VERBOSE)
|
41 |
+
|
42 |
+
|
43 |
+
def copy_hparams(hparams):
|
44 |
+
"""Return a copy of an HParams instance."""
|
45 |
+
return HParams(**hparams.values())
|
46 |
+
|
47 |
+
|
48 |
+
def print_config(hps):
|
49 |
+
for key, val in six.iteritems(hps.values()):
|
50 |
+
print('%s = %s' % (key, str(val)))
|
51 |
+
|
52 |
+
|
53 |
+
def save_config(output_file, hps, verbose=True):
|
54 |
+
def convert(o): # json cannot serialize integer in np.int64 format
|
55 |
+
if isinstance(o, np.int64):
|
56 |
+
return int(o)
|
57 |
+
raise TypeError
|
58 |
+
if verbose:
|
59 |
+
print_config(hps)
|
60 |
+
with open(output_file, 'w') as f:
|
61 |
+
json.dump(hps.values(), f, indent=True, default=convert)
|
62 |
+
|
63 |
+
|
64 |
+
def load_config(hps, config_file, verbose=True):
|
65 |
+
"""
|
66 |
+
parse hparams from config file
|
67 |
+
:param hps: hparams object whose values to be updated
|
68 |
+
:param config_file: json config file
|
69 |
+
:param verbose: print out values
|
70 |
+
"""
|
71 |
+
try:
|
72 |
+
with open(config_file, 'r') as fin:
|
73 |
+
hps.parse_json(fin.read())
|
74 |
+
if verbose:
|
75 |
+
print_config(hps)
|
76 |
+
except Exception as e:
|
77 |
+
print('Error reading config file %s: %s.\nConfig will not be updated.' % (config_file, e))
|
78 |
+
# return hps
|
79 |
+
|
80 |
+
|
81 |
+
def _parse_fail(name, var_type, value, values):
|
82 |
+
"""Helper function for raising a value error for bad assignment."""
|
83 |
+
raise ValueError(
|
84 |
+
'Could not parse hparam \'%s\' of type \'%s\' with value \'%s\' in %s' %
|
85 |
+
(name, var_type.__name__, value, values))
|
86 |
+
|
87 |
+
|
88 |
+
def _reuse_fail(name, values):
|
89 |
+
"""Helper function for raising a value error for reuse of name."""
|
90 |
+
raise ValueError('Multiple assignments to variable \'%s\' in %s' % (name,
|
91 |
+
values))
|
92 |
+
|
93 |
+
|
94 |
+
def _process_scalar_value(name, parse_fn, var_type, m_dict, values,
|
95 |
+
results_dictionary):
|
96 |
+
"""Update results_dictionary with a scalar value.
|
97 |
+
|
98 |
+
Used to update the results_dictionary to be returned by parse_values when
|
99 |
+
encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".)
|
100 |
+
|
101 |
+
Mutates results_dictionary.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
name: Name of variable in assignment ("s" or "arr").
|
105 |
+
parse_fn: Function for parsing the actual value.
|
106 |
+
var_type: Type of named variable.
|
107 |
+
m_dict: Dictionary constructed from regex parsing.
|
108 |
+
m_dict['val']: RHS value (scalar)
|
109 |
+
m_dict['index']: List index value (or None)
|
110 |
+
values: Full expression being parsed
|
111 |
+
results_dictionary: The dictionary being updated for return by the parsing
|
112 |
+
function.
|
113 |
+
|
114 |
+
Raises:
|
115 |
+
ValueError: If the name has already been used.
|
116 |
+
"""
|
117 |
+
try:
|
118 |
+
parsed_value = parse_fn(m_dict['val'])
|
119 |
+
except ValueError:
|
120 |
+
_parse_fail(name, var_type, m_dict['val'], values)
|
121 |
+
|
122 |
+
# If no index is provided
|
123 |
+
if not m_dict['index']:
|
124 |
+
if name in results_dictionary:
|
125 |
+
_reuse_fail(name, values)
|
126 |
+
results_dictionary[name] = parsed_value
|
127 |
+
else:
|
128 |
+
if name in results_dictionary:
|
129 |
+
# The name has already been used as a scalar, then it
|
130 |
+
# will be in this dictionary and map to a non-dictionary.
|
131 |
+
if not isinstance(results_dictionary.get(name), dict):
|
132 |
+
_reuse_fail(name, values)
|
133 |
+
else:
|
134 |
+
results_dictionary[name] = {}
|
135 |
+
|
136 |
+
index = int(m_dict['index'])
|
137 |
+
# Make sure the index position hasn't already been assigned a value.
|
138 |
+
if index in results_dictionary[name]:
|
139 |
+
_reuse_fail('{}[{}]'.format(name, index), values)
|
140 |
+
results_dictionary[name][index] = parsed_value
|
141 |
+
|
142 |
+
|
143 |
+
def _process_list_value(name, parse_fn, var_type, m_dict, values,
|
144 |
+
results_dictionary):
|
145 |
+
"""Update results_dictionary from a list of values.
|
146 |
+
|
147 |
+
Used to update results_dictionary to be returned by parse_values when
|
148 |
+
encountering a clause with a list RHS (e.g. "arr=[1,2,3]".)
|
149 |
+
|
150 |
+
Mutates results_dictionary.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
name: Name of variable in assignment ("arr").
|
154 |
+
parse_fn: Function for parsing individual values.
|
155 |
+
var_type: Type of named variable.
|
156 |
+
m_dict: Dictionary constructed from regex parsing.
|
157 |
+
m_dict['val']: RHS value (scalar)
|
158 |
+
values: Full expression being parsed
|
159 |
+
results_dictionary: The dictionary being updated for return by the parsing
|
160 |
+
function.
|
161 |
+
|
162 |
+
Raises:
|
163 |
+
ValueError: If the name has an index or the values cannot be parsed.
|
164 |
+
"""
|
165 |
+
if m_dict['index'] is not None:
|
166 |
+
raise ValueError('Assignment of a list to a list index.')
|
167 |
+
elements = filter(None, re.split('[ ,]', m_dict['vals']))
|
168 |
+
# Make sure the name hasn't already been assigned a value
|
169 |
+
if name in results_dictionary:
|
170 |
+
raise _reuse_fail(name, values)
|
171 |
+
try:
|
172 |
+
results_dictionary[name] = [parse_fn(e) for e in elements]
|
173 |
+
except ValueError:
|
174 |
+
_parse_fail(name, var_type, m_dict['vals'], values)
|
175 |
+
|
176 |
+
|
177 |
+
def _cast_to_type_if_compatible(name, param_type, value):
|
178 |
+
"""Cast hparam to the provided type, if compatible.
|
179 |
+
|
180 |
+
Args:
|
181 |
+
name: Name of the hparam to be cast.
|
182 |
+
param_type: The type of the hparam.
|
183 |
+
value: The value to be cast, if compatible.
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
The result of casting `value` to `param_type`.
|
187 |
+
|
188 |
+
Raises:
|
189 |
+
ValueError: If the type of `value` is not compatible with param_type.
|
190 |
+
* If `param_type` is a string type, but `value` is not.
|
191 |
+
* If `param_type` is a boolean, but `value` is not, or vice versa.
|
192 |
+
* If `param_type` is an integer type, but `value` is not.
|
193 |
+
* If `param_type` is a float type, but `value` is not a numeric type.
|
194 |
+
"""
|
195 |
+
fail_msg = (
|
196 |
+
"Could not cast hparam '%s' of type '%s' from value %r" %
|
197 |
+
(name, param_type, value))
|
198 |
+
|
199 |
+
# Some callers use None, for which we can't do any casting/checking. :(
|
200 |
+
if issubclass(param_type, type(None)):
|
201 |
+
return value
|
202 |
+
|
203 |
+
# Avoid converting a non-string type to a string.
|
204 |
+
if (issubclass(param_type, (six.string_types, six.binary_type)) and
|
205 |
+
not isinstance(value, (six.string_types, six.binary_type))):
|
206 |
+
raise ValueError(fail_msg)
|
207 |
+
|
208 |
+
# Avoid converting a number or string type to a boolean or vice versa.
|
209 |
+
if issubclass(param_type, bool) != isinstance(value, bool):
|
210 |
+
raise ValueError(fail_msg)
|
211 |
+
|
212 |
+
# Avoid converting float to an integer (the reverse is fine).
|
213 |
+
if (issubclass(param_type, numbers.Integral) and
|
214 |
+
not isinstance(value, numbers.Integral)):
|
215 |
+
raise ValueError(fail_msg)
|
216 |
+
|
217 |
+
# Avoid converting a non-numeric type to a numeric type.
|
218 |
+
if (issubclass(param_type, numbers.Number) and
|
219 |
+
not isinstance(value, numbers.Number)):
|
220 |
+
raise ValueError(fail_msg)
|
221 |
+
|
222 |
+
return param_type(value)
|
223 |
+
|
224 |
+
|
225 |
+
def parse_values(values, type_map, ignore_unknown=False):
|
226 |
+
"""Parses hyperparameter values from a string into a python map.
|
227 |
+
|
228 |
+
`values` is a string containing comma-separated `name=value` pairs.
|
229 |
+
For each pair, the value of the hyperparameter named `name` is set to
|
230 |
+
`value`.
|
231 |
+
|
232 |
+
If a hyperparameter name appears multiple times in `values`, a ValueError
|
233 |
+
is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2').
|
234 |
+
|
235 |
+
If a hyperparameter name in both an index assignment and scalar assignment,
|
236 |
+
a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1').
|
237 |
+
|
238 |
+
The hyperparameter name may contain '.' symbols, which will result in an
|
239 |
+
attribute name that is only accessible through the getattr and setattr
|
240 |
+
functions. (And must be first explicit added through add_hparam.)
|
241 |
+
|
242 |
+
WARNING: Use of '.' in your variable names is allowed, but is not well
|
243 |
+
supported and not recommended.
|
244 |
+
|
245 |
+
The `value` in `name=value` must follows the syntax according to the
|
246 |
+
type of the parameter:
|
247 |
+
|
248 |
+
* Scalar integer: A Python-parsable integer point value. E.g.: 1,
|
249 |
+
100, -12.
|
250 |
+
* Scalar float: A Python-parsable floating point value. E.g.: 1.0,
|
251 |
+
-.54e89.
|
252 |
+
* Boolean: Either true or false.
|
253 |
+
* Scalar string: A non-empty sequence of characters, excluding comma,
|
254 |
+
spaces, and square brackets. E.g.: foo, bar_1.
|
255 |
+
* List: A comma separated list of scalar values of the parameter type
|
256 |
+
enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low].
|
257 |
+
|
258 |
+
When index assignment is used, the corresponding type_map key should be the
|
259 |
+
list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not
|
260 |
+
"arr[1]").
|
261 |
+
|
262 |
+
Args:
|
263 |
+
values: String. Comma separated list of `name=value` pairs where
|
264 |
+
'value' must follow the syntax described above.
|
265 |
+
type_map: A dictionary mapping hyperparameter names to types. Note every
|
266 |
+
parameter name in values must be a key in type_map. The values must
|
267 |
+
conform to the types indicated, where a value V is said to conform to a
|
268 |
+
type T if either V has type T, or V is a list of elements of type T.
|
269 |
+
Hence, for a multidimensional parameter 'x' taking float values,
|
270 |
+
'x=[0.1,0.2]' will parse successfully if type_map['x'] = float.
|
271 |
+
ignore_unknown: Bool. Whether values that are missing a type in type_map
|
272 |
+
should be ignored. If set to True, a ValueError will not be raised for
|
273 |
+
unknown hyperparameter type.
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
A python map mapping each name to either:
|
277 |
+
* A scalar value.
|
278 |
+
* A list of scalar values.
|
279 |
+
* A dictionary mapping index numbers to scalar values.
|
280 |
+
(e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}")
|
281 |
+
|
282 |
+
Raises:
|
283 |
+
ValueError: If there is a problem with input.
|
284 |
+
* If `values` cannot be parsed.
|
285 |
+
* If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]').
|
286 |
+
* If the same rvalue is assigned two different values (e.g. 'a=1,a=2',
|
287 |
+
'a[1]=1,a[1]=2', or 'a=1,a=[1]')
|
288 |
+
"""
|
289 |
+
results_dictionary = {}
|
290 |
+
pos = 0
|
291 |
+
while pos < len(values):
|
292 |
+
m = PARAM_RE.match(values, pos)
|
293 |
+
if not m:
|
294 |
+
raise ValueError('Malformed hyperparameter value: %s' % values[pos:])
|
295 |
+
# Check that there is a comma between parameters and move past it.
|
296 |
+
pos = m.end()
|
297 |
+
# Parse the values.
|
298 |
+
m_dict = m.groupdict()
|
299 |
+
name = m_dict['name']
|
300 |
+
if name not in type_map:
|
301 |
+
if ignore_unknown:
|
302 |
+
continue
|
303 |
+
raise ValueError('Unknown hyperparameter type for %s' % name)
|
304 |
+
type_ = type_map[name]
|
305 |
+
|
306 |
+
# Set up correct parsing function (depending on whether type_ is a bool)
|
307 |
+
if type_ == bool:
|
308 |
+
def parse_bool(value):
|
309 |
+
if value in ['true', 'True']:
|
310 |
+
return True
|
311 |
+
elif value in ['false', 'False']:
|
312 |
+
return False
|
313 |
+
else:
|
314 |
+
try:
|
315 |
+
return bool(int(value))
|
316 |
+
except ValueError:
|
317 |
+
_parse_fail(name, type_, value, values)
|
318 |
+
|
319 |
+
parse = parse_bool
|
320 |
+
else:
|
321 |
+
parse = type_
|
322 |
+
|
323 |
+
# If a singe value is provided
|
324 |
+
if m_dict['val'] is not None:
|
325 |
+
_process_scalar_value(name, parse, type_, m_dict, values,
|
326 |
+
results_dictionary)
|
327 |
+
|
328 |
+
# If the assigned value is a list:
|
329 |
+
elif m_dict['vals'] is not None:
|
330 |
+
_process_list_value(name, parse, type_, m_dict, values,
|
331 |
+
results_dictionary)
|
332 |
+
|
333 |
+
else: # Not assigned a list or value
|
334 |
+
_parse_fail(name, type_, '', values)
|
335 |
+
|
336 |
+
return results_dictionary
|
337 |
+
|
338 |
+
|
339 |
+
class HParams(object):
|
340 |
+
"""Class to hold a set of hyperparameters as name-value pairs.
|
341 |
+
|
342 |
+
A `HParams` object holds hyperparameters used to build and train a model,
|
343 |
+
such as the number of hidden units in a neural net layer or the learning rate
|
344 |
+
to use when training.
|
345 |
+
|
346 |
+
You first create a `HParams` object by specifying the names and values of the
|
347 |
+
hyperparameters.
|
348 |
+
|
349 |
+
To make them easily accessible the parameter names are added as direct
|
350 |
+
attributes of the class. A typical usage is as follows:
|
351 |
+
|
352 |
+
```python
|
353 |
+
# Create a HParams object specifying names and values of the model
|
354 |
+
# hyperparameters:
|
355 |
+
hparams = HParams(learning_rate=0.1, num_hidden_units=100)
|
356 |
+
|
357 |
+
# The hyperparameter are available as attributes of the HParams object:
|
358 |
+
hparams.learning_rate ==> 0.1
|
359 |
+
hparams.num_hidden_units ==> 100
|
360 |
+
```
|
361 |
+
|
362 |
+
Hyperparameters have type, which is inferred from the type of their value
|
363 |
+
passed at construction type. The currently supported types are: integer,
|
364 |
+
float, boolean, string, and list of integer, float, boolean, or string.
|
365 |
+
|
366 |
+
You can override hyperparameter values by calling the
|
367 |
+
[`parse()`](#HParams.parse) method, passing a string of comma separated
|
368 |
+
`name=value` pairs. This is intended to make it possible to override
|
369 |
+
any hyperparameter values from a single command-line flag to which
|
370 |
+
the user passes 'hyper-param=value' pairs. It avoids having to define
|
371 |
+
one flag for each hyperparameter.
|
372 |
+
|
373 |
+
The syntax expected for each value depends on the type of the parameter.
|
374 |
+
See `parse()` for a description of the syntax.
|
375 |
+
|
376 |
+
Example:
|
377 |
+
|
378 |
+
```python
|
379 |
+
# Define a command line flag to pass name=value pairs.
|
380 |
+
# For example using argparse:
|
381 |
+
import argparse
|
382 |
+
parser = argparse.ArgumentParser(description='Train my model.')
|
383 |
+
parser.add_argument('--hparams', type=str,
|
384 |
+
help='Comma separated list of "name=value" pairs.')
|
385 |
+
args = parser.parse_args()
|
386 |
+
...
|
387 |
+
def my_program():
|
388 |
+
# Create a HParams object specifying the names and values of the
|
389 |
+
# model hyperparameters:
|
390 |
+
hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100,
|
391 |
+
activations=['relu', 'tanh'])
|
392 |
+
|
393 |
+
# Override hyperparameters values by parsing the command line
|
394 |
+
hparams.parse(args.hparams)
|
395 |
+
|
396 |
+
# If the user passed `--hparams=learning_rate=0.3` on the command line
|
397 |
+
# then 'hparams' has the following attributes:
|
398 |
+
hparams.learning_rate ==> 0.3
|
399 |
+
hparams.num_hidden_units ==> 100
|
400 |
+
hparams.activations ==> ['relu', 'tanh']
|
401 |
+
|
402 |
+
# If the hyperparameters are in json format use parse_json:
|
403 |
+
hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}')
|
404 |
+
```
|
405 |
+
"""
|
406 |
+
|
407 |
+
_HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks.
|
408 |
+
|
409 |
+
def __init__(self, model_structure=None, **kwargs):
|
410 |
+
"""Create an instance of `HParams` from keyword arguments.
|
411 |
+
|
412 |
+
The keyword arguments specify name-values pairs for the hyperparameters.
|
413 |
+
The parameter types are inferred from the type of the values passed.
|
414 |
+
|
415 |
+
The parameter names are added as attributes of `HParams` object, so they
|
416 |
+
can be accessed directly with the dot notation `hparams._name_`.
|
417 |
+
|
418 |
+
Example:
|
419 |
+
|
420 |
+
```python
|
421 |
+
# Define 3 hyperparameters: 'learning_rate' is a float parameter,
|
422 |
+
# 'num_hidden_units' an integer parameter, and 'activation' a string
|
423 |
+
# parameter.
|
424 |
+
hparams = tf.HParams(
|
425 |
+
learning_rate=0.1, num_hidden_units=100, activation='relu')
|
426 |
+
|
427 |
+
hparams.activation ==> 'relu'
|
428 |
+
```
|
429 |
+
|
430 |
+
Note that a few names are reserved and cannot be used as hyperparameter
|
431 |
+
names. If you use one of the reserved name the constructor raises a
|
432 |
+
`ValueError`.
|
433 |
+
|
434 |
+
Args:
|
435 |
+
model_structure: An instance of ModelStructure, defining the feature
|
436 |
+
crosses to be used in the Trial.
|
437 |
+
**kwargs: Key-value pairs where the key is the hyperparameter name and
|
438 |
+
the value is the value for the parameter.
|
439 |
+
|
440 |
+
Raises:
|
441 |
+
ValueError: If both `hparam_def` and initialization values are provided,
|
442 |
+
or if one of the arguments is invalid.
|
443 |
+
|
444 |
+
"""
|
445 |
+
# Register the hyperparameters and their type in _hparam_types.
|
446 |
+
# This simplifies the implementation of parse().
|
447 |
+
# _hparam_types maps the parameter name to a tuple (type, bool).
|
448 |
+
# The type value is the type of the parameter for scalar hyperparameters,
|
449 |
+
# or the type of the list elements for multidimensional hyperparameters.
|
450 |
+
# The bool value is True if the value is a list, False otherwise.
|
451 |
+
self._hparam_types = {}
|
452 |
+
self._model_structure = model_structure
|
453 |
+
for name, value in six.iteritems(kwargs):
|
454 |
+
self.add_hparam(name, value)
|
455 |
+
|
456 |
+
def __add__(self, other):
|
457 |
+
"""
|
458 |
+
addition operation keeping key order
|
459 |
+
"""
|
460 |
+
out = HParams()
|
461 |
+
for key in self._hparam_types.keys():
|
462 |
+
out.add_hparam(key, getattr(self, key))
|
463 |
+
for key in other._hparam_types.keys():
|
464 |
+
if getattr(out, key, None) is None: # add new param
|
465 |
+
out.add_hparam(key, getattr(other, key))
|
466 |
+
else: # update existing param
|
467 |
+
out.set_hparam(key, getattr(other, key))
|
468 |
+
return out
|
469 |
+
|
470 |
+
def __str__(self):
|
471 |
+
s = 'HParams(\n'
|
472 |
+
for key, val in six.iteritems(self.values()):
|
473 |
+
s += f'\t{key} = {val}\n'
|
474 |
+
# print('%s = %s' % (key, str(val)))
|
475 |
+
s += ')'
|
476 |
+
return s
|
477 |
+
|
478 |
+
def __repr__(self):
|
479 |
+
return self.__str__()
|
480 |
+
|
481 |
+
def add_hparam(self, name, value):
|
482 |
+
"""Adds {name, value} pair to hyperparameters.
|
483 |
+
|
484 |
+
Args:
|
485 |
+
name: Name of the hyperparameter.
|
486 |
+
value: Value of the hyperparameter. Can be one of the following types:
|
487 |
+
int, float, string, int list, float list, or string list.
|
488 |
+
|
489 |
+
Raises:
|
490 |
+
ValueError: if one of the arguments is invalid.
|
491 |
+
"""
|
492 |
+
# Keys in kwargs are unique, but 'name' could the name of a pre-existing
|
493 |
+
# attribute of this object. In that case we refuse to use it as a
|
494 |
+
# hyperparameter name.
|
495 |
+
if getattr(self, name, None) is not None:
|
496 |
+
raise ValueError('Hyperparameter name is reserved: %s' % name)
|
497 |
+
if isinstance(value, (list, tuple)):
|
498 |
+
if not value:
|
499 |
+
raise ValueError(
|
500 |
+
'Multi-valued hyperparameters cannot be empty: %s' % name)
|
501 |
+
self._hparam_types[name] = (type(value[0]), True)
|
502 |
+
else:
|
503 |
+
self._hparam_types[name] = (type(value), False)
|
504 |
+
setattr(self, name, value)
|
505 |
+
|
506 |
+
def set_hparam(self, name, value):
|
507 |
+
"""Set the value of an existing hyperparameter.
|
508 |
+
|
509 |
+
This function verifies that the type of the value matches the type of the
|
510 |
+
existing hyperparameter.
|
511 |
+
|
512 |
+
Args:
|
513 |
+
name: Name of the hyperparameter.
|
514 |
+
value: New value of the hyperparameter.
|
515 |
+
|
516 |
+
Raises:
|
517 |
+
KeyError: If the hyperparameter doesn't exist.
|
518 |
+
ValueError: If there is a type mismatch.
|
519 |
+
"""
|
520 |
+
param_type, is_list = self._hparam_types[name]
|
521 |
+
if isinstance(value, list):
|
522 |
+
if not is_list:
|
523 |
+
raise ValueError(
|
524 |
+
'Must not pass a list for single-valued parameter: %s' % name)
|
525 |
+
setattr(self, name, [
|
526 |
+
_cast_to_type_if_compatible(name, param_type, v) for v in value])
|
527 |
+
else:
|
528 |
+
if is_list:
|
529 |
+
raise ValueError(
|
530 |
+
'Must pass a list for multi-valued parameter: %s.' % name)
|
531 |
+
setattr(self, name, _cast_to_type_if_compatible(name, param_type, value))
|
532 |
+
|
533 |
+
def del_hparam(self, name):
|
534 |
+
"""Removes the hyperparameter with key 'name'.
|
535 |
+
|
536 |
+
Does nothing if it isn't present.
|
537 |
+
|
538 |
+
Args:
|
539 |
+
name: Name of the hyperparameter.
|
540 |
+
"""
|
541 |
+
if hasattr(self, name):
|
542 |
+
delattr(self, name)
|
543 |
+
del self._hparam_types[name]
|
544 |
+
|
545 |
+
def parse(self, values):
|
546 |
+
"""Override existing hyperparameter values, parsing new values from a string.
|
547 |
+
|
548 |
+
See parse_values for more detail on the allowed format for values.
|
549 |
+
|
550 |
+
Args:
|
551 |
+
values: String. Comma separated list of `name=value` pairs where 'value'
|
552 |
+
must follow the syntax described above.
|
553 |
+
|
554 |
+
Returns:
|
555 |
+
The `HParams` instance.
|
556 |
+
|
557 |
+
Raises:
|
558 |
+
ValueError: If `values` cannot be parsed or a hyperparameter in `values`
|
559 |
+
doesn't exist.
|
560 |
+
"""
|
561 |
+
type_map = {}
|
562 |
+
for name, t in self._hparam_types.items():
|
563 |
+
param_type, _ = t
|
564 |
+
type_map[name] = param_type
|
565 |
+
|
566 |
+
values_map = parse_values(values, type_map)
|
567 |
+
return self.override_from_dict(values_map)
|
568 |
+
|
569 |
+
def override_from_dict(self, values_dict):
|
570 |
+
"""Override existing hyperparameter values, parsing new values from a dictionary.
|
571 |
+
|
572 |
+
Args:
|
573 |
+
values_dict: Dictionary of name:value pairs.
|
574 |
+
|
575 |
+
Returns:
|
576 |
+
The `HParams` instance.
|
577 |
+
|
578 |
+
Raises:
|
579 |
+
KeyError: If a hyperparameter in `values_dict` doesn't exist.
|
580 |
+
ValueError: If `values_dict` cannot be parsed.
|
581 |
+
"""
|
582 |
+
for name, value in values_dict.items():
|
583 |
+
self.set_hparam(name, value)
|
584 |
+
return self
|
585 |
+
|
586 |
+
def set_model_structure(self, model_structure):
|
587 |
+
self._model_structure = model_structure
|
588 |
+
|
589 |
+
def get_model_structure(self):
|
590 |
+
return self._model_structure
|
591 |
+
|
592 |
+
def to_json(self, indent=None, separators=None, sort_keys=False):
|
593 |
+
"""Serializes the hyperparameters into JSON.
|
594 |
+
|
595 |
+
Args:
|
596 |
+
indent: If a non-negative integer, JSON array elements and object members
|
597 |
+
will be pretty-printed with that indent level. An indent level of 0, or
|
598 |
+
negative, will only insert newlines. `None` (the default) selects the
|
599 |
+
most compact representation.
|
600 |
+
separators: Optional `(item_separator, key_separator)` tuple. Default is
|
601 |
+
`(', ', ': ')`.
|
602 |
+
sort_keys: If `True`, the output dictionaries will be sorted by key.
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
A JSON string.
|
606 |
+
"""
|
607 |
+
def remove_callables(x):
|
608 |
+
"""Omit callable elements from input with arbitrary nesting."""
|
609 |
+
if isinstance(x, dict):
|
610 |
+
return {k: remove_callables(v) for k, v in six.iteritems(x)
|
611 |
+
if not callable(v)}
|
612 |
+
elif isinstance(x, list):
|
613 |
+
return [remove_callables(i) for i in x if not callable(i)]
|
614 |
+
return x
|
615 |
+
return json.dumps(
|
616 |
+
remove_callables(self.values()),
|
617 |
+
indent=indent,
|
618 |
+
separators=separators,
|
619 |
+
sort_keys=sort_keys)
|
620 |
+
|
621 |
+
def parse_json(self, values_json):
|
622 |
+
"""Override existing hyperparameter values, parsing new values from a json object.
|
623 |
+
|
624 |
+
Args:
|
625 |
+
values_json: String containing a json object of name:value pairs.
|
626 |
+
|
627 |
+
Returns:
|
628 |
+
The `HParams` instance.
|
629 |
+
|
630 |
+
Raises:
|
631 |
+
KeyError: If a hyperparameter in `values_json` doesn't exist.
|
632 |
+
ValueError: If `values_json` cannot be parsed.
|
633 |
+
"""
|
634 |
+
values_map = json.loads(values_json)
|
635 |
+
return self.override_from_dict(values_map)
|
636 |
+
|
637 |
+
def values(self):
|
638 |
+
"""Return the hyperparameter values as a Python dictionary.
|
639 |
+
|
640 |
+
Returns:
|
641 |
+
A dictionary with hyperparameter names as keys. The values are the
|
642 |
+
hyperparameter values.
|
643 |
+
"""
|
644 |
+
return {n: getattr(self, n) for n in self._hparam_types.keys()}
|
645 |
+
|
646 |
+
def get(self, key, default=None):
|
647 |
+
"""Returns the value of `key` if it exists, else `default`."""
|
648 |
+
if key in self._hparam_types:
|
649 |
+
# Ensure that default is compatible with the parameter type.
|
650 |
+
if default is not None:
|
651 |
+
param_type, is_param_list = self._hparam_types[key]
|
652 |
+
type_str = 'list<%s>' % param_type if is_param_list else str(param_type)
|
653 |
+
fail_msg = ("Hparam '%s' of type '%s' is incompatible with "
|
654 |
+
'default=%s' % (key, type_str, default))
|
655 |
+
|
656 |
+
is_default_list = isinstance(default, list)
|
657 |
+
if is_param_list != is_default_list:
|
658 |
+
raise ValueError(fail_msg)
|
659 |
+
|
660 |
+
try:
|
661 |
+
if is_default_list:
|
662 |
+
for value in default:
|
663 |
+
_cast_to_type_if_compatible(key, param_type, value)
|
664 |
+
else:
|
665 |
+
_cast_to_type_if_compatible(key, param_type, default)
|
666 |
+
except ValueError as e:
|
667 |
+
raise ValueError('%s. %s' % (fail_msg, e))
|
668 |
+
|
669 |
+
return getattr(self, key)
|
670 |
+
|
671 |
+
return default
|
672 |
+
|
673 |
+
def __contains__(self, key):
|
674 |
+
return key in self._hparam_types
|
675 |
+
|
676 |
+
@staticmethod
|
677 |
+
def _get_kind_name(param_type, is_list):
|
678 |
+
"""Returns the field name given parameter type and is_list.
|
679 |
+
|
680 |
+
Args:
|
681 |
+
param_type: Data type of the hparam.
|
682 |
+
is_list: Whether this is a list.
|
683 |
+
|
684 |
+
Returns:
|
685 |
+
A string representation of the field name.
|
686 |
+
|
687 |
+
Raises:
|
688 |
+
ValueError: If parameter type is not recognized.
|
689 |
+
"""
|
690 |
+
if issubclass(param_type, bool):
|
691 |
+
# This check must happen before issubclass(param_type, six.integer_types),
|
692 |
+
# since Python considers bool to be a subclass of int.
|
693 |
+
typename = 'bool'
|
694 |
+
elif issubclass(param_type, six.integer_types):
|
695 |
+
# Setting 'int' and 'long' types to be 'int64' to ensure the type is
|
696 |
+
# compatible with both Python2 and Python3.
|
697 |
+
typename = 'int64'
|
698 |
+
elif issubclass(param_type, (six.string_types, six.binary_type)):
|
699 |
+
# Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is
|
700 |
+
# compatible with both Python2 and Python3.
|
701 |
+
typename = 'bytes'
|
702 |
+
elif issubclass(param_type, float):
|
703 |
+
typename = 'float'
|
704 |
+
else:
|
705 |
+
raise ValueError('Unsupported parameter type: %s' % str(param_type))
|
706 |
+
|
707 |
+
suffix = 'list' if is_list else 'value'
|
708 |
+
return '_'.join([typename, suffix])
|
709 |
+
|
710 |
+
@staticmethod
|
711 |
+
def save_config(self, output_file, verbose=True):
|
712 |
+
def convert(o): # json cannot serialize integer in np.int64 format
|
713 |
+
if isinstance(o, np.int64):
|
714 |
+
return int(o)
|
715 |
+
raise TypeError
|
716 |
+
if verbose:
|
717 |
+
print(self)
|
718 |
+
with open(output_file, 'w') as f:
|
719 |
+
json.dump(self.values(), f, indent=True, default=convert)
|
720 |
+
|
721 |
+
@staticmethod
|
722 |
+
def load_config(config_file, verbose=True):
|
723 |
+
"""
|
724 |
+
parse hparams from config file
|
725 |
+
:param config_file: json config file
|
726 |
+
:param verbose: print out values
|
727 |
+
"""
|
728 |
+
try:
|
729 |
+
with open(config_file, 'r') as fin:
|
730 |
+
json_dict = json.loads(fin.read())
|
731 |
+
hps = HParams(**json_dict)
|
732 |
+
if verbose:
|
733 |
+
print_config(hps)
|
734 |
+
except Exception as e:
|
735 |
+
print('Error reading config file %s: %s.\nConfig will not be updated.' % (config_file, e))
|
736 |
+
return hps
|
737 |
+
|
738 |
+
@staticmethod
|
739 |
+
def clone(self):
|
740 |
+
"""
|
741 |
+
return a deep copy of this object
|
742 |
+
"""
|
743 |
+
return HParams(**self.values)
|
tools/image_dataset.py
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
imagefolder loader
|
5 |
+
inspired from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
|
6 |
+
@author: Tu Bui @surrey.ac.uk
|
7 |
+
"""
|
8 |
+
from __future__ import absolute_import
|
9 |
+
from __future__ import division
|
10 |
+
from __future__ import print_function
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import io
|
14 |
+
import time
|
15 |
+
import pandas as pd
|
16 |
+
import numpy as np
|
17 |
+
import random
|
18 |
+
from PIL import Image
|
19 |
+
from typing import Any, Callable, List, Optional, Tuple
|
20 |
+
import torch
|
21 |
+
from torchvision import transforms
|
22 |
+
from .base_lmdb import PILlmdb, ArrayDatabase
|
23 |
+
# from . import debug
|
24 |
+
|
25 |
+
|
26 |
+
def worker_init_fn(worker_id):
|
27 |
+
# to be passed to torch.utils.data.DataLoader to fix the
|
28 |
+
# random seed issue with numpy in multi-worker settings
|
29 |
+
torch_seed = torch.initial_seed()
|
30 |
+
random.seed(torch_seed + worker_id)
|
31 |
+
if torch_seed >= 2**30: # make sure torch_seed + workder_id < 2**32
|
32 |
+
torch_seed = torch_seed % 2**30
|
33 |
+
np.random.seed(torch_seed + worker_id)
|
34 |
+
|
35 |
+
|
36 |
+
def pil_loader(path: str) -> Image.Image:
|
37 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
38 |
+
with open(path, 'rb') as f:
|
39 |
+
img = Image.open(f)
|
40 |
+
return img.convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
def dataset_wrapper(data_dir, data_list, **kwargs):
|
44 |
+
if os.path.exists(os.path.join(data_dir, 'data.mdb')):
|
45 |
+
return ImageDataset(data_dir, data_list, **kwargs)
|
46 |
+
else:
|
47 |
+
return ImageFolder(data_dir, data_list, **kwargs)
|
48 |
+
|
49 |
+
|
50 |
+
class ImageFolder(torch.utils.data.Dataset):
|
51 |
+
_repr_indent = 4
|
52 |
+
def __init__(self, data_dir, data_list, secret_len=100, resize=256, transform=None, **kwargs):
|
53 |
+
super().__init__()
|
54 |
+
self.transform = transforms.RandomResizedCrop((resize, resize), scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333)) if transform is None else transform
|
55 |
+
self.build_data(data_dir, data_list, **kwargs)
|
56 |
+
self.kwargs = kwargs
|
57 |
+
self.secret_len = secret_len
|
58 |
+
|
59 |
+
def build_data(self, data_dir, data_list, **kwargs):
|
60 |
+
self.data_dir = data_dir
|
61 |
+
if isinstance(data_list, list):
|
62 |
+
self.data_list = data_list
|
63 |
+
elif isinstance(data_list, str):
|
64 |
+
self.data_list = pd.read_csv(data_list)['path'].tolist()
|
65 |
+
elif isinstance(data_list, pd.DataFrame):
|
66 |
+
self.data_list = data_list['path'].tolist()
|
67 |
+
else:
|
68 |
+
raise ValueError('data_list must be a list, str or pd.DataFrame')
|
69 |
+
self.N = len(self.data_list)
|
70 |
+
|
71 |
+
def __getitem__(self, index):
|
72 |
+
path = self.data_list[index]
|
73 |
+
img = pil_loader(os.path.join(self.data_dir, path))
|
74 |
+
img = self.transform(img)
|
75 |
+
img = np.array(img, dtype=np.float32)/127.5-1. # [-1, 1]
|
76 |
+
secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
|
77 |
+
return {'image': img, 'secret': secret} # {'img': x, 'index': index}
|
78 |
+
|
79 |
+
def __len__(self) -> int:
|
80 |
+
# raise NotImplementedError
|
81 |
+
return self.N
|
82 |
+
|
83 |
+
class ImageDataset(torch.utils.data.Dataset):
|
84 |
+
r"""
|
85 |
+
Customised Image Folder class for pytorch.
|
86 |
+
Accept lmdb and a csv list as the input.
|
87 |
+
Usage:
|
88 |
+
dataset = ImageDataset(img_dir, img_list)
|
89 |
+
dataset.set_transform(some_pytorch_transforms)
|
90 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
|
91 |
+
num_workers=4, worker_init_fn=worker_init_fn)
|
92 |
+
for x,y in loader:
|
93 |
+
# x and y is input and target (dict), the keys can be customised.
|
94 |
+
"""
|
95 |
+
_repr_indent = 4
|
96 |
+
def __init__(self, data_dir, data_list, secret_len=100, resize=None, transform=None, target_transform=None, **kwargs):
|
97 |
+
super().__init__()
|
98 |
+
if resize is not None:
|
99 |
+
self.resize = transforms.Resize((resize, resize))
|
100 |
+
self.set_transform(transform, target_transform)
|
101 |
+
self.build_data(data_dir, data_list, **kwargs)
|
102 |
+
self.secret_len = secret_len
|
103 |
+
self.kwargs = kwargs
|
104 |
+
|
105 |
+
def set_transform(self, transform, target_transform=None):
|
106 |
+
self.transform, self.target_transform = transform, target_transform
|
107 |
+
|
108 |
+
def build_data(self, data_dir, data_list, **kwargs):
|
109 |
+
"""
|
110 |
+
Args:
|
111 |
+
data_list (text file) must have at least 3 fields: id, path and label
|
112 |
+
|
113 |
+
This method must create an attribute self.samples containing ID, input and target samples; and another attribute N storing the dataset size
|
114 |
+
|
115 |
+
Optional attributes: classes (list of unique classes), group (useful for
|
116 |
+
metric learning)
|
117 |
+
"""
|
118 |
+
self.data_dir, self.list = data_dir, data_list
|
119 |
+
if ('dtype' in kwargs) and (kwargs['dtype'].lower() == 'array'):
|
120 |
+
data = ArrayDatabase(data_dir, data_list)
|
121 |
+
else:
|
122 |
+
data = PILlmdb(data_dir, data_list, **kwargs)
|
123 |
+
self.N = len(data)
|
124 |
+
self.classes = np.unique(data.labels)
|
125 |
+
self.samples = {'x': data, 'y': data.labels}
|
126 |
+
# assert isinstance(data_list, str) or isinstance(data_list, pd.DataFrame)
|
127 |
+
# df = pd.read_csv(data_list) if isinstance(data_list, str) else data_list
|
128 |
+
# assert 'id' in df and 'label' in df, f'[DATA] Error! {data_list} must contains "id" and "label".'
|
129 |
+
# ids = df['id'].tolist()
|
130 |
+
# labels = np.array(df['label'].tolist())
|
131 |
+
# data = PILlmdb(data_dir)
|
132 |
+
# assert set(ids).issubset(set(data.keys)) # ids should exist in lmdb
|
133 |
+
# self.N = len(ids)
|
134 |
+
# self.classes, inds = np.unique(labels, return_index=True)
|
135 |
+
# self.samples = {'id': ids, 'x': data, 'y': labels}
|
136 |
+
|
137 |
+
def set_ids(self, ids):
|
138 |
+
self.samples['x'].set_ids(ids)
|
139 |
+
self.samples['y'] = [self.samples['y'][i] for i in ids]
|
140 |
+
self.N = len(self.samples['x'])
|
141 |
+
|
142 |
+
def __getitem__(self, index: int) -> Any:
|
143 |
+
"""
|
144 |
+
Args:
|
145 |
+
index (int): Index
|
146 |
+
Returns:
|
147 |
+
dict: (x: sample, y: target, **kwargs)
|
148 |
+
"""
|
149 |
+
x, y = self.samples['x'][index], self.samples['y'][index]
|
150 |
+
if hasattr(self, 'resize'):
|
151 |
+
x = self.resize(x)
|
152 |
+
if self.transform is not None:
|
153 |
+
x = self.transform(x)
|
154 |
+
if self.target_transform is not None:
|
155 |
+
y = self.target_transform(y)
|
156 |
+
x = np.array(x, dtype=np.float32)/127.5-1.
|
157 |
+
secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
|
158 |
+
return {'image': x, 'secret': secret} # {'img': x, 'index': index}
|
159 |
+
|
160 |
+
def __len__(self) -> int:
|
161 |
+
# raise NotImplementedError
|
162 |
+
return self.N
|
163 |
+
|
164 |
+
def __repr__(self) -> str:
|
165 |
+
head = "\nDataset " + self.__class__.__name__
|
166 |
+
body = ["Number of datapoints: {}".format(self.__len__())]
|
167 |
+
if hasattr(self, 'data_dir') and self.data_dir is not None:
|
168 |
+
body.append("data_dir location: {}".format(self.data_dir))
|
169 |
+
if hasattr(self, 'kwargs'):
|
170 |
+
body.append(f'kwargs: {self.kwargs}')
|
171 |
+
body += self.extra_repr().splitlines()
|
172 |
+
if hasattr(self, "transform") and self.transform is not None:
|
173 |
+
body += [repr(self.transform)]
|
174 |
+
lines = [head] + [" " * self._repr_indent + line for line in body]
|
175 |
+
return '\n'.join(lines)
|
176 |
+
|
177 |
+
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
|
178 |
+
lines = transform.__repr__().splitlines()
|
179 |
+
return (["{}{}".format(head, lines[0])] +
|
180 |
+
["{}{}".format(" " * len(head), line) for line in lines[1:]])
|
181 |
+
|
182 |
+
def extra_repr(self) -> str:
|
183 |
+
return ""
|
184 |
+
|
tools/image_dataset_generic.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
imagefolder loader
|
5 |
+
inspired from https://github.com/adambielski/siamese-triplet/blob/master/datasets.py
|
6 |
+
@author: Tu Bui @surrey.ac.uk
|
7 |
+
"""
|
8 |
+
from __future__ import absolute_import
|
9 |
+
from __future__ import division
|
10 |
+
from __future__ import print_function
|
11 |
+
import os
|
12 |
+
import sys
|
13 |
+
import io
|
14 |
+
import time
|
15 |
+
import pandas as pd
|
16 |
+
import numpy as np
|
17 |
+
import random
|
18 |
+
from PIL import Image
|
19 |
+
from typing import Any, Callable, List, Optional, Tuple
|
20 |
+
import torch
|
21 |
+
from .base_lmdb import PILlmdb, ArrayDatabase
|
22 |
+
from torchvision import transforms
|
23 |
+
# from . import debug
|
24 |
+
|
25 |
+
|
26 |
+
def worker_init_fn(worker_id):
|
27 |
+
# to be passed to torch.utils.data.DataLoader to fix the
|
28 |
+
# random seed issue with numpy in multi-worker settings
|
29 |
+
torch_seed = torch.initial_seed()
|
30 |
+
random.seed(torch_seed + worker_id)
|
31 |
+
if torch_seed >= 2**30: # make sure torch_seed + workder_id < 2**32
|
32 |
+
torch_seed = torch_seed % 2**30
|
33 |
+
np.random.seed(torch_seed + worker_id)
|
34 |
+
|
35 |
+
|
36 |
+
def pil_loader(path: str) -> Image.Image:
|
37 |
+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
|
38 |
+
with open(path, 'rb') as f:
|
39 |
+
img = Image.open(f)
|
40 |
+
return img.convert('RGB')
|
41 |
+
|
42 |
+
|
43 |
+
class ImageDataset(torch.utils.data.Dataset):
|
44 |
+
r"""
|
45 |
+
Customised Image Folder class for pytorch.
|
46 |
+
Accept lmdb and a csv list as the input.
|
47 |
+
Usage:
|
48 |
+
dataset = ImageDataset(img_dir, img_list)
|
49 |
+
dataset.set_transform(some_pytorch_transforms)
|
50 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True,
|
51 |
+
num_workers=4, worker_init_fn=worker_init_fn)
|
52 |
+
for x,y in loader:
|
53 |
+
# x and y is input and target (dict), the keys can be customised.
|
54 |
+
"""
|
55 |
+
_repr_indent = 4
|
56 |
+
def __init__(self, data_dir, data_list, secret_len=100, transform=None, target_transform=None, **kwargs):
|
57 |
+
super().__init__()
|
58 |
+
self.set_transform(transform, target_transform)
|
59 |
+
self.build_data(data_dir, data_list, **kwargs)
|
60 |
+
self.secret_len = secret_len
|
61 |
+
self.kwargs = kwargs
|
62 |
+
|
63 |
+
def set_transform(self, transform, target_transform=None):
|
64 |
+
self.transform, self.target_transform = transform, target_transform
|
65 |
+
|
66 |
+
def build_data(self, data_dir, data_list, **kwargs):
|
67 |
+
"""
|
68 |
+
Args:
|
69 |
+
data_list (text file) must have at least 3 fields: id, path and label
|
70 |
+
|
71 |
+
This method must create an attribute self.samples containing ID, input and target samples; and another attribute N storing the dataset size
|
72 |
+
|
73 |
+
Optional attributes: classes (list of unique classes), group (useful for
|
74 |
+
metric learning)
|
75 |
+
"""
|
76 |
+
self.data_dir, self.list = data_dir, data_list
|
77 |
+
if ('dtype' in kwargs) and (kwargs['dtype'].lower() == 'array'):
|
78 |
+
data = ArrayDatabase(data_dir, data_list)
|
79 |
+
else:
|
80 |
+
data = PILlmdb(data_dir, data_list, **kwargs)
|
81 |
+
self.N = len(data)
|
82 |
+
self.classes = np.unique(data.labels)
|
83 |
+
self.samples = {'x': data, 'y': data.labels}
|
84 |
+
|
85 |
+
def __getitem__(self, index: int) -> Any:
|
86 |
+
"""
|
87 |
+
Args:
|
88 |
+
index (int): Index
|
89 |
+
Returns:
|
90 |
+
dict: (x: sample, y: target, **kwargs)
|
91 |
+
"""
|
92 |
+
x, y = self.samples['x'][index], self.samples['y'][index]
|
93 |
+
if self.transform is not None:
|
94 |
+
x = self.transform(x)
|
95 |
+
if self.target_transform is not None:
|
96 |
+
y = self.target_transform(y)
|
97 |
+
x = np.array(x, dtype=np.float32)/127.5-1.
|
98 |
+
secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
|
99 |
+
return {'image': x, 'secret': secret} # {'img': x, 'index': index}
|
100 |
+
|
101 |
+
def __len__(self) -> int:
|
102 |
+
# raise NotImplementedError
|
103 |
+
return self.N
|
104 |
+
|
105 |
+
def __repr__(self) -> str:
|
106 |
+
head = "\nDataset " + self.__class__.__name__
|
107 |
+
body = ["Number of datapoints: {}".format(self.__len__())]
|
108 |
+
if hasattr(self, 'data_dir') and self.data_dir is not None:
|
109 |
+
body.append("data_dir location: {}".format(self.data_dir))
|
110 |
+
if hasattr(self, 'kwargs'):
|
111 |
+
body.append(f'kwargs: {self.kwargs}')
|
112 |
+
body += self.extra_repr().splitlines()
|
113 |
+
if hasattr(self, "transform") and self.transform is not None:
|
114 |
+
body += [repr(self.transform)]
|
115 |
+
lines = [head] + [" " * self._repr_indent + line for line in body]
|
116 |
+
return '\n'.join(lines)
|
117 |
+
|
118 |
+
def _format_transform_repr(self, transform: Callable, head: str) -> List[str]:
|
119 |
+
lines = transform.__repr__().splitlines()
|
120 |
+
return (["{}{}".format(head, lines[0])] +
|
121 |
+
["{}{}".format(" " * len(head), line) for line in lines[1:]])
|
122 |
+
|
123 |
+
def extra_repr(self) -> str:
|
124 |
+
return ""
|
125 |
+
|
126 |
+
class ImageFolder(torch.utils.data.Dataset):
|
127 |
+
_repr_indent = 4
|
128 |
+
def __init__(self, data_dir, data_list, secret_len=100, resize=256, transform=None, **kwargs):
|
129 |
+
super().__init__()
|
130 |
+
self.transform = transforms.Resize((resize, resize)) if transform is None else transform
|
131 |
+
self.build_data(data_dir, data_list, **kwargs)
|
132 |
+
self.kwargs = kwargs
|
133 |
+
self.secret_len = secret_len
|
134 |
+
|
135 |
+
def build_data(self, data_dir, data_list, **kwargs):
|
136 |
+
self.data_dir = data_dir
|
137 |
+
if isinstance(data_list, list):
|
138 |
+
self.data_list = data_list
|
139 |
+
elif isinstance(data_list, str):
|
140 |
+
self.data_list = pd.read_csv(data_list)['path'].tolist()
|
141 |
+
elif isinstance(data_list, pd.DataFrame):
|
142 |
+
self.data_list = data_list['path'].tolist()
|
143 |
+
else:
|
144 |
+
raise ValueError('data_list must be a list, str or pd.DataFrame')
|
145 |
+
self.N = len(self.data_list)
|
146 |
+
|
147 |
+
def __getitem__(self, index):
|
148 |
+
path = self.data_list[index]
|
149 |
+
img = pil_loader(os.path.join(self.data_dir, path))
|
150 |
+
img = self.transform(img)
|
151 |
+
img = np.array(img, dtype=np.float32)/127.5-1. # [-1, 1]
|
152 |
+
secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2) # not used
|
153 |
+
return {'image': img, 'secret': secret} # {'img': x, 'index': index}
|
154 |
+
|
155 |
+
def __len__(self) -> int:
|
156 |
+
# raise NotImplementedError
|
157 |
+
return self.N
|
tools/image_tools.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
|
5 |
+
@author: Tu Bui @surrey.ac.uk
|
6 |
+
"""
|
7 |
+
from __future__ import absolute_import
|
8 |
+
from __future__ import division
|
9 |
+
from __future__ import print_function
|
10 |
+
from scipy import fftpack
|
11 |
+
import sys, os
|
12 |
+
from pathlib import Path
|
13 |
+
import numpy as np
|
14 |
+
import random
|
15 |
+
import glob
|
16 |
+
import json
|
17 |
+
import time
|
18 |
+
import importlib
|
19 |
+
import pandas as pd
|
20 |
+
from tqdm import tqdm
|
21 |
+
# from IPython.display import display
|
22 |
+
# import seaborn as sns
|
23 |
+
import matplotlib
|
24 |
+
# matplotlib.use('Agg') # headless run
|
25 |
+
import matplotlib.pyplot as plt
|
26 |
+
import matplotlib.patches as mpatches
|
27 |
+
from PIL import Image, ImageDraw, ImageFont
|
28 |
+
cmap = plt.get_cmap("tab10") # cmap as function
|
29 |
+
cmap = plt.rcParams['axes.prop_cycle'].by_key()['color'] # cmap
|
30 |
+
|
31 |
+
FONT = '/vol/research/tubui1/_base/utils/FreeSans.ttf'
|
32 |
+
|
33 |
+
# def imshow(im):
|
34 |
+
# if type(im) is np.ndarray:
|
35 |
+
# im = Image.fromarray(im)
|
36 |
+
# display(im)
|
37 |
+
|
38 |
+
def make_grid(array_list, gsize=(3,3)):
|
39 |
+
"""
|
40 |
+
make a grid image from a list of image array (RGB)
|
41 |
+
return: array RGB
|
42 |
+
"""
|
43 |
+
assert len(gsize)==2 and gsize[0]*gsize[1]==len(array_list)
|
44 |
+
h,w,c = array_list[0].shape
|
45 |
+
out = np.array(array_list).reshape(gsize[0], gsize[1], h, w, c).transpose(0, 2, 1, 3, 4).reshape(gsize[0]*h, gsize[1]*w, c)
|
46 |
+
return out
|
47 |
+
|
48 |
+
def collage(im_list, size=None, pad=0, color=255):
|
49 |
+
"""
|
50 |
+
generalised function of make_grid()
|
51 |
+
work on PIL/numpy images of arbitrary size
|
52 |
+
"""
|
53 |
+
if size is None:
|
54 |
+
size=(1, len(im_list))
|
55 |
+
assert len(size)==2
|
56 |
+
if isinstance(im_list[0], np.ndarray):
|
57 |
+
im_list = [Image.fromarray(im) for im in im_list]
|
58 |
+
h, w = size
|
59 |
+
n = len(im_list)
|
60 |
+
canvas = []
|
61 |
+
for i in range(h):
|
62 |
+
start, end = i*w, min((i+1)*w, n)
|
63 |
+
row = combine_horz(im_list[start:end], pad, color)
|
64 |
+
canvas.append(row)
|
65 |
+
canvas = combine_vert(canvas, pad, color)
|
66 |
+
return canvas
|
67 |
+
|
68 |
+
def combine_horz(pil_ims, pad=0, c=255):
|
69 |
+
"""
|
70 |
+
Combines multiple pil_ims into a single side-by-side PIL image object.
|
71 |
+
"""
|
72 |
+
widths, heights = zip(*(i.size for i in pil_ims))
|
73 |
+
total_width = sum(widths) + (len(pil_ims)-1) * pad
|
74 |
+
max_height = max(heights)
|
75 |
+
color = (c,c,c)
|
76 |
+
new_im = Image.new('RGB', (total_width, max_height), color)
|
77 |
+
x_offset = 0
|
78 |
+
for im in pil_ims:
|
79 |
+
new_im.paste(im, (x_offset,0))
|
80 |
+
x_offset += (im.size[0] + pad)
|
81 |
+
return new_im
|
82 |
+
|
83 |
+
|
84 |
+
def combine_vert(pil_ims, pad=0, c=255):
|
85 |
+
"""
|
86 |
+
Combines multiple pil_ims into a single vertical PIL image object.
|
87 |
+
"""
|
88 |
+
widths, heights = zip(*(i.size for i in pil_ims))
|
89 |
+
max_width = max(widths)
|
90 |
+
total_height = sum(heights) + (len(pil_ims)-1)*pad
|
91 |
+
color = (c,c,c)
|
92 |
+
new_im = Image.new('RGB', (max_width, total_height), color)
|
93 |
+
y_offset = 0
|
94 |
+
for im in pil_ims:
|
95 |
+
new_im.paste(im, (0,y_offset))
|
96 |
+
y_offset += (im.size[1] + pad)
|
97 |
+
return new_im
|
98 |
+
|
99 |
+
def make_text_image(img_shape=(100,20), text='hello', font_path=FONT, offset=(0,0), font_size=16):
|
100 |
+
"""
|
101 |
+
make a text image with given width/height and font size
|
102 |
+
Args:
|
103 |
+
img_shape, offset tuple (width, height)
|
104 |
+
font_path path to font file (TrueType)
|
105 |
+
font_size max font size, actual may smaller
|
106 |
+
|
107 |
+
Return:
|
108 |
+
pil image
|
109 |
+
"""
|
110 |
+
im = Image.new('RGB', tuple(img_shape), (255,255,255))
|
111 |
+
draw = ImageDraw.Draw(im)
|
112 |
+
|
113 |
+
def get_font_size(max_font_size):
|
114 |
+
font = ImageFont.truetype(font_path, max_font_size)
|
115 |
+
text_size = font.getsize(text) # (w,h)
|
116 |
+
start_w = int((img_shape[0] - text_size[0]) / 2)
|
117 |
+
start_h = int((img_shape[1] - text_size[1])/2)
|
118 |
+
if start_h <0 or start_w < 0:
|
119 |
+
return get_font_size(max_font_size-2)
|
120 |
+
else:
|
121 |
+
return font, (start_w, start_h)
|
122 |
+
font, pos = get_font_size(font_size)
|
123 |
+
pos = (pos[0]+offset[0], pos[1]+offset[1])
|
124 |
+
draw.text(pos, text, font=font, fill=0)
|
125 |
+
return im
|
126 |
+
|
127 |
+
|
128 |
+
def log_scale(array, epsilon=1e-12):
|
129 |
+
"""Log scale the input array.
|
130 |
+
"""
|
131 |
+
array = np.abs(array)
|
132 |
+
array += epsilon # no zero in log
|
133 |
+
array = np.log(array)
|
134 |
+
return array
|
135 |
+
|
136 |
+
def dct2(array):
|
137 |
+
"""2D DCT"""
|
138 |
+
array = fftpack.dct(array, type=2, norm="ortho", axis=0)
|
139 |
+
array = fftpack.dct(array, type=2, norm="ortho", axis=1)
|
140 |
+
return array
|
141 |
+
|
142 |
+
def idct2(array):
|
143 |
+
"""inverse 2D DCT"""
|
144 |
+
array = fftpack.idct(array, type=2, norm="ortho", axis=0)
|
145 |
+
array = fftpack.idct(array, type=2, norm="ortho", axis=1)
|
146 |
+
return array
|
147 |
+
|
148 |
+
|
149 |
+
class DCT(object):
|
150 |
+
def __init__(self, log=True):
|
151 |
+
self.log = log
|
152 |
+
|
153 |
+
def __call__(self, x):
|
154 |
+
x = np.array(x)
|
155 |
+
x = dct2(x)
|
156 |
+
if self.log:
|
157 |
+
x = log_scale(x)
|
158 |
+
# normalize
|
159 |
+
x = np.clip((x - x.min())/(x.max() - x.min()) * 255, 0, 255).astype(np.uint8)
|
160 |
+
return Image.fromarray(x)
|
161 |
+
|
162 |
+
def __repr__(self):
|
163 |
+
s = f'(Discrete Cosine Transform, logarithm={self.log})'
|
164 |
+
return self.__class__.__name__ + s
|
tools/imgcap_dataset.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
Dataset class for image-caption
|
5 |
+
@author: Tu Bui @University of Surrey
|
6 |
+
"""
|
7 |
+
import json
|
8 |
+
from PIL import Image
|
9 |
+
import numpy as np
|
10 |
+
from pathlib import Path
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import Dataset, DataLoader
|
13 |
+
from functools import partial
|
14 |
+
import pytorch_lightning as pl
|
15 |
+
from ldm.util import instantiate_from_config
|
16 |
+
import pandas as pd
|
17 |
+
|
18 |
+
|
19 |
+
def worker_init_fn(_):
|
20 |
+
worker_info = torch.utils.data.get_worker_info()
|
21 |
+
worker_id = worker_info.id
|
22 |
+
return np.random.seed(np.random.get_state()[1][0] + worker_id)
|
23 |
+
|
24 |
+
|
25 |
+
class WrappedDataset(Dataset):
|
26 |
+
"""Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset"""
|
27 |
+
|
28 |
+
def __init__(self, dataset):
|
29 |
+
self.data = dataset
|
30 |
+
|
31 |
+
def __len__(self):
|
32 |
+
return len(self.data)
|
33 |
+
|
34 |
+
def __getitem__(self, idx):
|
35 |
+
return self.data[idx]
|
36 |
+
|
37 |
+
|
38 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
39 |
+
def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False,
|
40 |
+
shuffle_val_dataloader=False):
|
41 |
+
super().__init__()
|
42 |
+
self.batch_size = batch_size
|
43 |
+
self.dataset_configs = dict()
|
44 |
+
self.num_workers = num_workers if num_workers is not None else batch_size * 2
|
45 |
+
self.use_worker_init_fn = use_worker_init_fn
|
46 |
+
if train is not None:
|
47 |
+
self.dataset_configs["train"] = train
|
48 |
+
self.train_dataloader = self._train_dataloader
|
49 |
+
if validation is not None:
|
50 |
+
self.dataset_configs["validation"] = validation
|
51 |
+
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
|
52 |
+
if test is not None:
|
53 |
+
self.dataset_configs["test"] = test
|
54 |
+
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
|
55 |
+
if predict is not None:
|
56 |
+
self.dataset_configs["predict"] = predict
|
57 |
+
self.predict_dataloader = self._predict_dataloader
|
58 |
+
self.wrap = wrap
|
59 |
+
|
60 |
+
def prepare_data(self):
|
61 |
+
for data_cfg in self.dataset_configs.values():
|
62 |
+
instantiate_from_config(data_cfg)
|
63 |
+
|
64 |
+
def setup(self, stage=None):
|
65 |
+
self.datasets = dict(
|
66 |
+
(k, instantiate_from_config(self.dataset_configs[k]))
|
67 |
+
for k in self.dataset_configs)
|
68 |
+
if self.wrap:
|
69 |
+
for k in self.datasets:
|
70 |
+
self.datasets[k] = WrappedDataset(self.datasets[k])
|
71 |
+
|
72 |
+
def _train_dataloader(self):
|
73 |
+
if self.use_worker_init_fn:
|
74 |
+
init_fn = worker_init_fn
|
75 |
+
else:
|
76 |
+
init_fn = None
|
77 |
+
return DataLoader(self.datasets["train"], batch_size=self.batch_size,
|
78 |
+
num_workers=self.num_workers, shuffle=True,
|
79 |
+
worker_init_fn=init_fn)
|
80 |
+
|
81 |
+
def _val_dataloader(self, shuffle=False):
|
82 |
+
if self.use_worker_init_fn:
|
83 |
+
init_fn = worker_init_fn
|
84 |
+
else:
|
85 |
+
init_fn = None
|
86 |
+
return DataLoader(self.datasets["validation"],
|
87 |
+
batch_size=self.batch_size,
|
88 |
+
num_workers=self.num_workers,
|
89 |
+
worker_init_fn=init_fn,
|
90 |
+
shuffle=shuffle)
|
91 |
+
|
92 |
+
def _test_dataloader(self, shuffle=False):
|
93 |
+
if self.use_worker_init_fn:
|
94 |
+
init_fn = worker_init_fn
|
95 |
+
else:
|
96 |
+
init_fn = None
|
97 |
+
|
98 |
+
return DataLoader(self.datasets["test"], batch_size=self.batch_size,
|
99 |
+
num_workers=self.num_workers, worker_init_fn=init_fn, shuffle=shuffle)
|
100 |
+
|
101 |
+
def _predict_dataloader(self, shuffle=False):
|
102 |
+
if self.use_worker_init_fn:
|
103 |
+
init_fn = worker_init_fn
|
104 |
+
else:
|
105 |
+
init_fn = None
|
106 |
+
return DataLoader(self.datasets["predict"], batch_size=self.batch_size,
|
107 |
+
num_workers=self.num_workers, worker_init_fn=init_fn)
|
108 |
+
|
109 |
+
|
110 |
+
class ImageCaptionRaw(Dataset):
|
111 |
+
def __init__(self, image_dir, caption_file, secret_len=100, transform=None):
|
112 |
+
super().__init__()
|
113 |
+
self.image_dir = Path(image_dir)
|
114 |
+
self.data = []
|
115 |
+
with open(caption_file, 'rt') as f:
|
116 |
+
for line in f:
|
117 |
+
self.data.append(json.loads(line))
|
118 |
+
self.secret_len = secret_len
|
119 |
+
self.transform = transform
|
120 |
+
|
121 |
+
def __len__(self):
|
122 |
+
return len(self.data)
|
123 |
+
|
124 |
+
def __getitem__(self, idx):
|
125 |
+
item = self.data[idx]
|
126 |
+
image = Image.open(self.image_dir/item['image']).convert('RGB').resize((512,512))
|
127 |
+
caption = item['captions']
|
128 |
+
cid = torch.randint(0, len(caption), (1,)).item()
|
129 |
+
caption = caption[cid]
|
130 |
+
if self.transform is not None:
|
131 |
+
image = self.transform(image)
|
132 |
+
|
133 |
+
image = np.array(image, dtype=np.float32)/ 255.0 # normalize to [0, 1]
|
134 |
+
target = image * 2.0 - 1.0 # normalize to [-1, 1]
|
135 |
+
secret = torch.zeros(self.secret_len, dtype=torch.float).random_(0, 2)
|
136 |
+
return dict(image=image, caption=caption, target=target, secret=secret)
|
137 |
+
|
138 |
+
|
139 |
+
class BAMFG(Dataset):
|
140 |
+
def __init__(self, style_dir, gt_dir, data_list, transform=None):
|
141 |
+
super().__init__()
|
142 |
+
self.style_dir = Path(style_dir)
|
143 |
+
self.gt_dir = Path(gt_dir)
|
144 |
+
self.data = pd.read_csv(data_list)
|
145 |
+
self.transform = transform
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return len(self.data)
|
149 |
+
|
150 |
+
def __getitem__(self, idx):
|
151 |
+
item = self.data.iloc[idx]
|
152 |
+
gt_img = Image.open(self.gt_dir/item['gt_img']).convert('RGB').resize((512,512))
|
153 |
+
style_img = Image.open(self.style_dir/item['style_img']).convert('RGB').resize((512,512))
|
154 |
+
txt = item['prompt']
|
155 |
+
if self.transform is not None:
|
156 |
+
gt_img = self.transform(gt_img)
|
157 |
+
style_img = self.transform(style_img)
|
158 |
+
|
159 |
+
gt_img = np.array(gt_img, dtype=np.float32)/ 255.0 # normalize to [0, 1]
|
160 |
+
style_img = np.array(style_img, dtype=np.float32)/ 255.0 # normalize to [0, 1]
|
161 |
+
target = gt_img * 2.0 - 1.0 # normalize to [-1, 1]
|
162 |
+
|
163 |
+
return dict(image=gt_img, txt=txt, hint=style_img)
|
tools/sifid.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from scipy import linalg
|
4 |
+
import torchvision
|
5 |
+
from torchvision import transforms
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
|
11 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
12 |
+
"""Numpy implementation of the Frechet Distance.
|
13 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
14 |
+
and X_2 ~ N(mu_2, C_2) is
|
15 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
16 |
+
Stable version by Dougal J. Sutherland.
|
17 |
+
Params:
|
18 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
19 |
+
inception net (like returned by the function 'get_predictions')
|
20 |
+
for generated samples.
|
21 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
22 |
+
representative data set.
|
23 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
24 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
25 |
+
representative data set.
|
26 |
+
Returns:
|
27 |
+
-- : The Frechet Distance.
|
28 |
+
"""
|
29 |
+
|
30 |
+
mu1 = np.atleast_1d(mu1)
|
31 |
+
mu2 = np.atleast_1d(mu2)
|
32 |
+
|
33 |
+
sigma1 = np.atleast_2d(sigma1)
|
34 |
+
sigma2 = np.atleast_2d(sigma2)
|
35 |
+
|
36 |
+
assert mu1.shape == mu2.shape, \
|
37 |
+
'Training and test mean vectors have different lengths'
|
38 |
+
assert sigma1.shape == sigma2.shape, \
|
39 |
+
'Training and test covariances have different dimensions'
|
40 |
+
|
41 |
+
diff = mu1 - mu2
|
42 |
+
|
43 |
+
# Product might be almost singular
|
44 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
45 |
+
if not np.isfinite(covmean).all():
|
46 |
+
msg = ('fid calculation produces singular product; '
|
47 |
+
'adding %s to diagonal of cov estimates') % eps
|
48 |
+
print(msg)
|
49 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
50 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
51 |
+
|
52 |
+
# Numerical error might give slight imaginary component
|
53 |
+
if np.iscomplexobj(covmean):
|
54 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
55 |
+
m = np.max(np.abs(covmean.imag))
|
56 |
+
raise ValueError('Imaginary component {}'.format(m))
|
57 |
+
covmean = covmean.real
|
58 |
+
|
59 |
+
tr_covmean = np.trace(covmean)
|
60 |
+
|
61 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
62 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
63 |
+
|
64 |
+
|
65 |
+
class SIFID(object):
|
66 |
+
def __init__(self, dims=64) -> None:
|
67 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
|
68 |
+
self.model = InceptionV3([block_idx]).cuda()
|
69 |
+
self.model.eval()
|
70 |
+
self.dims = dims
|
71 |
+
|
72 |
+
def calculate_activation_statistics(self, x):
|
73 |
+
act = self.get_activations(x)
|
74 |
+
mu = np.mean(act, axis=0)
|
75 |
+
sigma = np.cov(act, rowvar=False)
|
76 |
+
return mu, sigma
|
77 |
+
|
78 |
+
def get_activations(self, x):
|
79 |
+
# x tensor (B, C, H, W) in range [0, 1]
|
80 |
+
batch_size = x.shape[0]
|
81 |
+
with torch.no_grad():
|
82 |
+
pred = self.model(x)[0]
|
83 |
+
pred = pred.cpu().numpy()
|
84 |
+
pred = pred.transpose(0, 2, 3, 1).reshape(batch_size*pred.shape[2]*pred.shape[3],-1)
|
85 |
+
return pred
|
86 |
+
|
87 |
+
def __call__(self, x1, x2):
|
88 |
+
# x1, x2 tensor (B, C, H, W) in range [-1, 1]
|
89 |
+
x1, x2 = (x1 + 1.)/2, (x2 + 1.)/2 # [-1, 1] -> [0, 1]
|
90 |
+
m1, s1 = self.calculate_activation_statistics(x1.unsqueeze(0).cuda())
|
91 |
+
m2, s2 = self.calculate_activation_statistics(x2.unsqueeze(0).cuda())
|
92 |
+
return calculate_frechet_distance(m1, s1, m2, s2)
|
93 |
+
|
94 |
+
|
95 |
+
class InceptionV3(nn.Module):
|
96 |
+
"""Pretrained InceptionV3 network returning feature maps"""
|
97 |
+
|
98 |
+
# Index of default block of inception to return,
|
99 |
+
# corresponds to output of final average pooling
|
100 |
+
DEFAULT_BLOCK_INDEX = 3
|
101 |
+
|
102 |
+
# Maps feature dimensionality to their output blocks indices
|
103 |
+
BLOCK_INDEX_BY_DIM = {
|
104 |
+
64: 0, # First max pooling features
|
105 |
+
192: 1, # Second max pooling featurs
|
106 |
+
768: 2, # Pre-aux classifier features
|
107 |
+
2048: 3 # Final average pooling features
|
108 |
+
}
|
109 |
+
|
110 |
+
def __init__(self,
|
111 |
+
output_blocks=[DEFAULT_BLOCK_INDEX],
|
112 |
+
resize_input=False,
|
113 |
+
normalize_input=True,
|
114 |
+
requires_grad=False):
|
115 |
+
"""Build pretrained InceptionV3
|
116 |
+
Parameters
|
117 |
+
----------
|
118 |
+
output_blocks : list of int
|
119 |
+
Indices of blocks to return features of. Possible values are:
|
120 |
+
- 0: corresponds to output of first max pooling
|
121 |
+
- 1: corresponds to output of second max pooling
|
122 |
+
- 2: corresponds to output which is fed to aux classifier
|
123 |
+
- 3: corresponds to output of final average pooling
|
124 |
+
resize_input : bool
|
125 |
+
If true, bilinearly resizes input to width and height 299 before
|
126 |
+
feeding input to model. As the network without fully connected
|
127 |
+
layers is fully convolutional, it should be able to handle inputs
|
128 |
+
of arbitrary size, so resizing might not be strictly needed
|
129 |
+
normalize_input : bool
|
130 |
+
If true, scales the input from range (0, 1) to the range the
|
131 |
+
pretrained Inception network expects, namely (-1, 1)
|
132 |
+
requires_grad : bool
|
133 |
+
If true, parameters of the model require gradient. Possibly useful
|
134 |
+
for finetuning the network
|
135 |
+
"""
|
136 |
+
super(InceptionV3, self).__init__()
|
137 |
+
|
138 |
+
self.resize_input = resize_input
|
139 |
+
self.normalize_input = normalize_input
|
140 |
+
self.output_blocks = sorted(output_blocks)
|
141 |
+
self.last_needed_block = max(output_blocks)
|
142 |
+
|
143 |
+
assert self.last_needed_block <= 3, \
|
144 |
+
'Last possible output block index is 3'
|
145 |
+
|
146 |
+
self.blocks = nn.ModuleList()
|
147 |
+
|
148 |
+
inception = torchvision.models.inception_v3(pretrained=True)
|
149 |
+
|
150 |
+
# Block 0: input to maxpool1
|
151 |
+
block0 = [
|
152 |
+
inception.Conv2d_1a_3x3,
|
153 |
+
inception.Conv2d_2a_3x3,
|
154 |
+
inception.Conv2d_2b_3x3,
|
155 |
+
]
|
156 |
+
|
157 |
+
|
158 |
+
self.blocks.append(nn.Sequential(*block0))
|
159 |
+
|
160 |
+
# Block 1: maxpool1 to maxpool2
|
161 |
+
if self.last_needed_block >= 1:
|
162 |
+
block1 = [
|
163 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
164 |
+
inception.Conv2d_3b_1x1,
|
165 |
+
inception.Conv2d_4a_3x3,
|
166 |
+
]
|
167 |
+
self.blocks.append(nn.Sequential(*block1))
|
168 |
+
|
169 |
+
# Block 2: maxpool2 to aux classifier
|
170 |
+
if self.last_needed_block >= 2:
|
171 |
+
block2 = [
|
172 |
+
nn.MaxPool2d(kernel_size=3, stride=2),
|
173 |
+
inception.Mixed_5b,
|
174 |
+
inception.Mixed_5c,
|
175 |
+
inception.Mixed_5d,
|
176 |
+
inception.Mixed_6a,
|
177 |
+
inception.Mixed_6b,
|
178 |
+
inception.Mixed_6c,
|
179 |
+
inception.Mixed_6d,
|
180 |
+
inception.Mixed_6e,
|
181 |
+
]
|
182 |
+
self.blocks.append(nn.Sequential(*block2))
|
183 |
+
|
184 |
+
# Block 3: aux classifier to final avgpool
|
185 |
+
if self.last_needed_block >= 3:
|
186 |
+
block3 = [
|
187 |
+
inception.Mixed_7a,
|
188 |
+
inception.Mixed_7b,
|
189 |
+
inception.Mixed_7c,
|
190 |
+
]
|
191 |
+
self.blocks.append(nn.Sequential(*block3))
|
192 |
+
|
193 |
+
if self.last_needed_block >= 4:
|
194 |
+
block4 = [
|
195 |
+
nn.AdaptiveAvgPool2d(output_size=(1, 1))
|
196 |
+
]
|
197 |
+
self.blocks.append(nn.Sequential(*block4))
|
198 |
+
|
199 |
+
for param in self.parameters():
|
200 |
+
param.requires_grad = requires_grad
|
201 |
+
|
202 |
+
def forward(self, inp):
|
203 |
+
"""Get Inception feature maps
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
inp : torch.autograd.Variable
|
207 |
+
Input tensor of shape Bx3xHxW. Values are expected to be in
|
208 |
+
range (0, 1)
|
209 |
+
Returns
|
210 |
+
-------
|
211 |
+
List of torch.autograd.Variable, corresponding to the selected output
|
212 |
+
block, sorted ascending by index
|
213 |
+
"""
|
214 |
+
outp = []
|
215 |
+
x = inp
|
216 |
+
|
217 |
+
if self.resize_input:
|
218 |
+
x = F.upsample(x,
|
219 |
+
size=(299, 299),
|
220 |
+
mode='bilinear',
|
221 |
+
align_corners=False)
|
222 |
+
|
223 |
+
if self.normalize_input:
|
224 |
+
x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
|
225 |
+
|
226 |
+
for idx, block in enumerate(self.blocks):
|
227 |
+
x = block(x)
|
228 |
+
if idx in self.output_blocks:
|
229 |
+
outp.append(x)
|
230 |
+
|
231 |
+
if idx == self.last_needed_block:
|
232 |
+
break
|
233 |
+
|
234 |
+
return outp
|
235 |
+
|
236 |
+
if __name__ == '__main__':
|
237 |
+
tform = transforms.Compose([transforms.Resize((256,256)),
|
238 |
+
transforms.ToTensor(),
|
239 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
|
240 |
+
im1 = Image.open('test1.jpg')
|
241 |
+
im2 = Image.open('test2.jpg')
|
242 |
+
im1 = tform(im1) # 3xHxW in [-1,]
|
243 |
+
im2 = tform(im2)
|
244 |
+
sifid_model = SIFID()
|
245 |
+
sifid_score = sifid_model(im1, im2)
|
246 |
+
print(sifid_score)
|
tools/slack_bot.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
slack_bot.py
|
5 |
+
Created on May 02 2020 11:02
|
6 |
+
a bot to send message/image during program run
|
7 |
+
@author: Tu Bui tu@surrey.ac.uk
|
8 |
+
"""
|
9 |
+
|
10 |
+
from __future__ import absolute_import
|
11 |
+
from __future__ import division
|
12 |
+
from __future__ import print_function
|
13 |
+
import os
|
14 |
+
import sys
|
15 |
+
import requests
|
16 |
+
import socket
|
17 |
+
from slack import WebClient
|
18 |
+
from slack.errors import SlackApiError
|
19 |
+
import threading
|
20 |
+
|
21 |
+
|
22 |
+
SLACK_MAX_PRINT_ERROR = 3
|
23 |
+
SLACK_ERROR_CODE = {'not_active': 1,
|
24 |
+
'API': 2}
|
25 |
+
|
26 |
+
|
27 |
+
def welcome_message():
|
28 |
+
hostname = socket.gethostname()
|
29 |
+
all_args = ' '.join(sys.argv)
|
30 |
+
out_text = 'On server {}: {}\n'.format(hostname, all_args)
|
31 |
+
return out_text
|
32 |
+
|
33 |
+
|
34 |
+
class Notifier(object):
|
35 |
+
"""
|
36 |
+
A slack bot to send text/image to a given workspace channel.
|
37 |
+
This class initializes with a text file as input, the text file should contain 2 lines:
|
38 |
+
slack token
|
39 |
+
slack channel
|
40 |
+
|
41 |
+
Usage:
|
42 |
+
msg = Notifier(token_file)
|
43 |
+
msg.send_initial_text(' '.join(sys.argv))
|
44 |
+
msg.send_text('hi, this text is inside slack thread')
|
45 |
+
msg.send_file(your_file, 'file title')
|
46 |
+
"""
|
47 |
+
def __init__(self, token_file):
|
48 |
+
"""
|
49 |
+
setup slack
|
50 |
+
:param token_file: path to slack token file
|
51 |
+
"""
|
52 |
+
self.active = True
|
53 |
+
self.thread_id = None
|
54 |
+
self.counter = 0 # count number of errors during Web API call
|
55 |
+
if not os.path.exists(token_file):
|
56 |
+
print('[SLACK] token file not found. You will not be notified.')
|
57 |
+
self.active = False
|
58 |
+
else:
|
59 |
+
try:
|
60 |
+
with open(token_file, 'r') as f:
|
61 |
+
lines = f.readlines()
|
62 |
+
self.token = lines[0].strip()
|
63 |
+
self.channel = lines[1].strip()
|
64 |
+
except Exception as e:
|
65 |
+
print(e)
|
66 |
+
print('[SLACK] fail to read token file. You will not be notified.')
|
67 |
+
self.active = False
|
68 |
+
|
69 |
+
def _handel_error(self, e):
|
70 |
+
assert e.response["ok"] is False
|
71 |
+
assert e.response["error"] # str like 'invalid_auth', 'channel_not_found'
|
72 |
+
self.counter += 1
|
73 |
+
if self.counter <= SLACK_MAX_PRINT_ERROR:
|
74 |
+
print(f"Got the following error, you will not be notified: {e.response['error']}")
|
75 |
+
|
76 |
+
def send_init_text(self, text=None):
|
77 |
+
"""
|
78 |
+
start a new thread with a main message and register the thread id
|
79 |
+
:param text: initial message for this thread
|
80 |
+
:return:
|
81 |
+
"""
|
82 |
+
if not self.active:
|
83 |
+
return SLACK_ERROR_CODE['not_active']
|
84 |
+
try:
|
85 |
+
if text is None:
|
86 |
+
text = welcome_message()
|
87 |
+
sc = WebClient(self.token)
|
88 |
+
response = sc.chat_postMessage(channel=self.channel, text=text)
|
89 |
+
self.thread_id = response['ts']
|
90 |
+
except SlackApiError as e:
|
91 |
+
self._handel_error(e)
|
92 |
+
return SLACK_ERROR_CODE['API']
|
93 |
+
print('[SLACK] sent initial text. Chat ID %s. Message %s' % (self.thread_id, text))
|
94 |
+
return 0
|
95 |
+
|
96 |
+
def send_init_file(self, file_path, title=''):
|
97 |
+
"""
|
98 |
+
start a new thread with a file and register thread id
|
99 |
+
:param file_path: path to file
|
100 |
+
:param title: title of this file
|
101 |
+
:return: 0 if success otherwise error code
|
102 |
+
"""
|
103 |
+
if not self.active:
|
104 |
+
return SLACK_ERROR_CODE['not_active']
|
105 |
+
try:
|
106 |
+
response = sc.files_upload(title=title, channels=self.channel, file=file_path)
|
107 |
+
self.thread_id = response['ts']
|
108 |
+
except SlackApiError as e:
|
109 |
+
self._handel_error(e)
|
110 |
+
return SLACK_ERROR_CODE['API']
|
111 |
+
print('[SLACK] sent initial file. Chat ID %s.' % self.thread_id)
|
112 |
+
return 0
|
113 |
+
|
114 |
+
def send_text(self, text, reply_broadcast=False):
|
115 |
+
"""
|
116 |
+
send text as a thread if one is registered in self.thread_id.
|
117 |
+
Otherwise send as a new message
|
118 |
+
:param text: message to send.
|
119 |
+
:return: 0 if success, error code otherwise
|
120 |
+
"""
|
121 |
+
print(text)
|
122 |
+
if not self.active:
|
123 |
+
return SLACK_ERROR_CODE['not_active']
|
124 |
+
if self.thread_id is None:
|
125 |
+
self.send_init_text(text)
|
126 |
+
else:
|
127 |
+
try:
|
128 |
+
sc = WebClient(self.token)
|
129 |
+
response = sc.chat_postMessage(channel=self.channel, text=text,
|
130 |
+
thread_ts=self.thread_id, as_user=True,
|
131 |
+
reply_broadcast=reply_broadcast)
|
132 |
+
except SlackApiError as e:
|
133 |
+
self._handel_error(e)
|
134 |
+
return SLACK_ERROR_CODE['API']
|
135 |
+
return 0
|
136 |
+
|
137 |
+
def _send_file(self, file_path, title='', reply_broadcast=False):
|
138 |
+
"""can be multithread target"""
|
139 |
+
try:
|
140 |
+
sc = WebClient(self.token)
|
141 |
+
sc.files_upload(title=title, channels=self.channel,
|
142 |
+
thread_ts=self.thread_id, file=file_path,
|
143 |
+
reply_broadcast=reply_broadcast)
|
144 |
+
except SlackApiError as e:
|
145 |
+
self._handel_error(e)
|
146 |
+
return SLACK_ERROR_CODE['API']
|
147 |
+
return 0
|
148 |
+
|
149 |
+
def send_file(self, file_path, title='', reply_broadcast=False):
|
150 |
+
if not self.active:
|
151 |
+
return SLACK_ERROR_CODE['not_active']
|
152 |
+
if self.thread_id is None:
|
153 |
+
return self.send_init_file(file_path, title)
|
154 |
+
else:
|
155 |
+
os_thread = threading.Thread(target=self._send_file, args=(file_path, title, reply_broadcast))
|
156 |
+
os_thread.start()
|
157 |
+
return 0 # may still have error if _send_file() fail
|