johnowhitaker committed
Commit 41333da
1 Parent(s): d0ff3f2

Create app.py

Files changed (1)
  1. app.py +262 -0
app.py ADDED
@@ -0,0 +1,262 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+
+ from imstack.core import ImStack
+ from tqdm.auto import tqdm  # tqdm.auto works both in and out of notebooks
+
+ import kornia.augmentation as K
+
+ from PIL import Image
+ import numpy as np
+
+ import warnings
+ warnings.filterwarnings('ignore')  # Some pytorch functions give warnings about behaviour changes that I don't want to see over and over again :)
+
+ device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+
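+ # Lanczos resampling helpers: sinc/lanczos/ramp build the separable windowed-sinc
+ # kernel that resample() further down uses to cleanly downscale cutouts.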
+ def sinc(x):
+     return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
+
+
+ def lanczos(x, a):
+     cond = torch.logical_and(-a < x, x < a)
+     out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
+     return out / out.sum()
+
+
+ def ramp(ratio, width):
+     n = math.ceil(width / ratio + 1)
+     out = torch.empty([n])
+     cur = 0
+     for i in range(out.shape[0]):
+         out[i] = cur
+         cur += ratio
+     return torch.cat([-out[1:].flip([0]), out])[1:-1]
+
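+ # A Prompt wraps a target embedding; forward() scores a batch of image embeddings
+ # by squared spherical distance to that target, scaled by the prompt's weight.
+ # (replace_grad is defined further down; the name is resolved at call time.)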
+ class Prompt(nn.Module):
+     def __init__(self, embed, weight=1., stop=float('-inf')):
+         super().__init__()
+         self.register_buffer('embed', embed)
+         self.register_buffer('weight', torch.as_tensor(weight))
+         self.register_buffer('stop', torch.as_tensor(stop))
+
+     def forward(self, input):
+         input_normed = F.normalize(input.unsqueeze(1), dim=2)
+         embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
+         dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
+         dists = dists * self.weight.sign()
+         return self.weight.abs() * replace_grad(dists, torch.maximum(dists, self.stop)).mean()
+
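+ # MakeCutouts takes cutn random square crops of the image (sizes controlled by
+ # cut_pow), resizes each to cut_size, then applies Kornia augmentations and a
+ # little noise -- giving the encoder varied views of the same image.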
+ class MakeCutouts(nn.Module):
+     def __init__(self, cut_size, cutn, cut_pow=1.):
+         super().__init__()
+         self.cut_size = cut_size
+         self.cutn = cutn
+         self.cut_pow = cut_pow
+         self.augs = nn.Sequential(
+             K.RandomHorizontalFlip(p=0.5),
+             K.RandomSharpness(0.3, p=0.4),
+             K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
+             K.RandomPerspective(0.2, p=0.4),
+             K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
+         self.noise_fac = 0.1
+
+     def forward(self, input):
+         sideY, sideX = input.shape[2:4]
+         max_size = min(sideX, sideY)
+         min_size = min(sideX, sideY, self.cut_size)
+         cutouts = []
+         for _ in range(self.cutn):
+             size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
+             offsetx = torch.randint(0, sideX - size + 1, ())
+             offsety = torch.randint(0, sideY - size + 1, ())
+             cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
+             cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
+         batch = self.augs(torch.cat(cutouts, dim=0))
+         if self.noise_fac:
+             facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
+             batch = batch + facs * torch.randn_like(batch)
+         return batch
+
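+ # resample: Lanczos-prefilter when shrinking (to avoid aliasing), then bicubic
+ # interpolation to the exact target size.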
+ def resample(input, size, align_corners=True):
+     n, c, h, w = input.shape
+     dh, dw = size
+
+     input = input.view([n * c, 1, h, w])
+
+     if dh < h:
+         kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
+         pad_h = (kernel_h.shape[0] - 1) // 2
+         input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
+         input = F.conv2d(input, kernel_h[None, None, :, None])
+
+     if dw < w:
+         kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
+         pad_w = (kernel_w.shape[0] - 1) // 2
+         input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
+         input = F.conv2d(input, kernel_w[None, None, None, :])
+
+     input = input.view([n, c, h, w])
+     return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
+
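+ # ReplaceGrad: the forward pass returns x_forward unchanged, but gradients flow
+ # as if x_backward had been returned -- Prompt uses it to clamp the loss at
+ # self.stop without zeroing the gradient.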
+ class ReplaceGrad(torch.autograd.Function):
+     @staticmethod
+     def forward(ctx, x_forward, x_backward):
+         ctx.shape = x_backward.shape
+         return x_forward
+
+     @staticmethod
+     def backward(ctx, grad_in):
+         return None, grad_in.sum_to_size(ctx.shape)
+
+
+ replace_grad = ReplaceGrad.apply
+
+ # Load CLOOB model (a CLIP-style text/image encoder)
+ import sys
+ sys.path.append('./cloob-training')
+ from cloob_training import model_pt, pretrained
+
+ config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
+ cloob = model_pt.get_pt_model(config)
+ checkpoint = pretrained.download_checkpoint(config)
+ cloob.load_state_dict(model_pt.get_pt_params(config, checkpoint))
+ cloob.eval().requires_grad_(False).to(device)
+ print('CLOOB model loaded')
+
+ # Load fastai model (downloads the Sketchy Unet weights on first run)
+
+ import gradio as gr
+ from fastai.vision.all import *
+ from os.path import exists
+ import requests
+
+ model_fn = 'quick_224px'
+ url = 'https://huggingface.co/johnowhitaker/sketchy_unet_rn34/resolve/main/quick_224px'
+
+ if not exists(model_fn):
+     print('starting download')
+     with requests.get(url, stream=True) as r:
+         r.raise_for_status()
+         with open(model_fn, 'wb') as f:
+             for chunk in r.iter_content(chunk_size=8192):
+                 f.write(chunk)
+     print('done')
+ else:
+     print('file exists')
+
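+ # The exported Learner was saved with custom get_x/get_y functions, so matching
+ # names must exist in this module for load_learner to unpickle it.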
+ def get_x(item): return None
+ def get_y(item): return None
+ sketch_model = load_learner(model_fn)
+
+ # Cutouts
+ cutn = 16
+ cut_pow = 1
+ make_cutouts = MakeCutouts(cloob.config['image_encoder']['image_size'], cutn, cut_pow)
+
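+ # make_cutouts maps a (1, 3, H, W) image to a (cutn, 3, image_size, image_size)
+ # batch of crops ready for the CLOOB image encoder.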
+ def process_im(image_path,
+                sketchify_first=True,
+                prompt='A watercolor painting of a face',
+                lr=0.03,
+                n_iter=10):
+
+     n_iter = int(n_iter)  # gradio Number inputs arrive as floats
+
+     pil_im = None
+
+     if sketchify_first:
+         pred = sketch_model.predict(image_path)
+         np_im = pred[0].permute(1, 2, 0).numpy()
+         pil_im = Image.fromarray(np_im.astype(np.uint8))
+     else:
+         pil_im = Image.open(image_path).resize((540, 540))
+
+     prompt_texts = [prompt]
+     weight_decay = 1e-4
+
+     out_size = 540
+     base_size = 8
+     n_layers = 5
+     scale = 3
+     layer_decay = 0.3
+
+     # The prompts
+     p_prompts = []
+     for pr in prompt_texts:
+         embed = cloob.text_encoder(cloob.tokenize(pr).to(device)).float()
+         p_prompts.append(Prompt(embed, 1, float('-inf')).to(device))  # 1 is the weight
+
+     # Some negative prompts
+     n_prompts = []
+     for pr in ["Random noise", 'saturated rainbow RGB deep dream']:
+         embed = cloob.text_encoder(cloob.tokenize(pr).to(device)).float()
+         n_prompts.append(Prompt(embed, 0.5, float('-inf')).to(device))  # 0.5 is the weight
+
+     # The ImageStack - trying a different scale and n_layers
+     ims = ImStack(base_size=base_size,
+                   scale=scale,
+                   n_layers=n_layers,
+                   out_size=out_size,
+                   decay=layer_decay,
+                   init_image=pil_im)
+
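+     # ImStack (from the imstack library) stores the image as a stack of learnable
+     # tensors at increasing resolutions; calling ims() composites them into a
+     # single out_size image, so optimizing the layers optimizes the image.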
+     # desaturate starting image
+     desat = 0.6
+
+     if desat != 1:
+         for i in range(n_layers):
+             ims.layers[i] = ims.layers[i].detach() * desat
+             ims.layers[i].requires_grad = True
+
226
+ optimizer = optim.Adam(ims.layers, lr=lr, weight_decay=weight_decay)
227
+ losses = []
228
+
229
+ for i in tqdm(range(n_iter)):
230
+ optimizer.zero_grad()
231
+
232
+ im = ims()
233
+ batch = cloob.normalize(make_cutouts(im))
234
+ iii = cloob.image_encoder(batch).float()
235
+
236
+ l = 0
237
+ for prompt in p_prompts:
238
+ l += prompt(iii)
239
+ for prompt in n_prompts:
240
+ l -= prompt(iii)
241
+
242
+ losses.append(float(l.detach().cpu()))
243
+ l.backward() # Backprop
244
+ optimizer.step() # Update
245
+
246
+ return ims.to_pil()
247
+
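+ # Example usage (hypothetical filenames), matching what the Gradio app does per request:
+ # out = process_im('portrait.jpg', sketchify_first=True,
+ #                  prompt='A charcoal and watercolor sketch of a person',
+ #                  lr=0.03, n_iter=10)
+ # out.save('stylized.png')
+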
+ iface = gr.Interface(fn=process_im,
+                      inputs=[
+                          gr.inputs.Image(label="Input Image", shape=(512, 512), type="filepath"),
+                          gr.inputs.Checkbox(label='Sketchify First', default=True),
+                          gr.inputs.Textbox(default="A charcoal and watercolor sketch of a person", label="Prompt"),
+                          gr.inputs.Number(default=0.03, label='LR'),
+                          gr.inputs.Number(default=10, label='num_steps'),
+                      ],
+                      outputs=[gr.outputs.Image(type="pil", label="Model Output")],
+                      title='Sketchy ImStack + CLOOB',
+                      description="Stylize an image with ImStack+CLOOB after a Sketchy Unet",
+                      article="More info on datasciencecastnet.home.blog")
+ iface.launch(enable_queue=True)