Spaces:
Runtime error
Runtime error
johnowhitaker
committed on
Commit
•
41333da
1
Parent(s):
d0ff3f2
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.autograd import Variable
|
5 |
+
import torch.optim as optim
|
6 |
+
|
7 |
+
from imstack.core import ImStack
|
8 |
+
from tqdm.notebook import tqdm
|
9 |
+
|
10 |
+
import kornia.augmentation as K
|
11 |
+
from CLIP import clip
|
12 |
+
from torchvision import transforms
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
import numpy as np
|
16 |
+
import math
|
17 |
+
|
18 |
+
from matplotlib import pyplot as plt
|
19 |
+
from fastprogress.fastprogress import master_bar, progress_bar
|
20 |
+
from IPython.display import HTML
|
21 |
+
from base64 import b64encode
|
22 |
+
|
23 |
+
import warnings
|
24 |
+
# Silence every warning: some pytorch functions repeatedly warn about
# behaviour changes, which only adds noise to the logs.
warnings.filterwarnings('ignore')

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
27 |
+
|
28 |
+
def sinc(x):
    """Return the normalized sinc of tensor ``x``: sin(pi*x) / (pi*x).

    Elements where ``x`` is exactly zero are set to 1 instead of the
    0/0 result the quotient would produce there.
    """
    scaled = math.pi * x
    return torch.where(x == 0, x.new_ones([]), torch.sin(scaled) / scaled)
|
30 |
+
|
31 |
+
|
32 |
+
def lanczos(x, a):
    """Evaluate a Lanczos kernel with window ``a`` at positions ``x``.

    The kernel is sinc(x) * sinc(x / a) strictly inside (-a, a) and zero
    elsewhere; the result is divided by its own sum so the taps add up to 1.
    """
    inside_window = (x > -a) & (x < a)
    taps = torch.where(inside_window, sinc(x) * sinc(x / a), x.new_zeros([]))
    return taps / taps.sum()
|
36 |
+
|
37 |
+
|
38 |
+
def ramp(ratio, width):
    """Return symmetric sample positions spaced ``ratio`` apart around zero.

    ``n = ceil(width / ratio + 1)`` non-negative positions (0, ratio,
    2*ratio, ...) are generated by repeated addition, mirrored to the
    negative side, and the outermost position on each side is dropped.
    """
    n = math.ceil(width / ratio + 1)
    steps = torch.empty([n])
    position = 0
    for idx in range(n):
        steps[idx] = position
        position += ratio
    # Negative half (reversed, without zero and without the largest step)
    # followed by the non-negative half without the largest step.
    negative_half = -steps[1:-1].flip([0])
    return torch.cat([negative_half, steps[:-1]])
|
46 |
+
|
47 |
+
class Prompt(nn.Module):
    """Weighted distance loss between input embeddings and a target embedding.

    ``embed``, ``weight`` and ``stop`` are stored as buffers, so they move
    with the module on ``.to(device)``.
    """

    def __init__(self, embed, weight=1., stop=float('-inf')):
        super().__init__()
        buffers = (
            ('embed', embed),
            ('weight', torch.as_tensor(weight)),
            ('stop', torch.as_tensor(stop)),
        )
        for name, value in buffers:
            self.register_buffer(name, value)

    def forward(self, input):
        """Return the weighted mean distance from ``input`` to the target.

        Both sides are L2-normalized along the last dimension; the distance
        is 2 * arcsin(chord / 2) ** 2, where chord is the norm of their
        difference. The sign of ``weight`` flips the distance, and ``stop``
        only affects the backward pass (via ``replace_grad``).
        """
        input_normed = F.normalize(input.unsqueeze(1), dim=2)
        embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
        chord = (input_normed - embed_normed).norm(dim=2)
        dists = torch.arcsin(chord / 2).pow(2) * 2
        dists = dists * self.weight.sign()
        # Forward value comes from `dists`; gradients flow through the
        # version clamped from below at `stop`.
        clamped = torch.maximum(dists, self.stop)
        return self.weight.abs() * replace_grad(dists, clamped).mean()
|
60 |
+
|
61 |
+
class MakeCutouts(nn.Module):
    """Cut random square crops from an image batch, resize and augment them.

    Args:
        cut_size: side length every cutout is resampled to.
        cutn: number of cutouts taken per forward call.
        cut_pow: exponent applied to the uniform random draw that picks each
            cutout's size.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow
        self.augs = nn.Sequential(
            K.RandomHorizontalFlip(p=0.5),
            K.RandomSharpness(0.3,p=0.4),
            K.RandomAffine(degrees=30, translate=0.1, p=0.8, padding_mode='border'),
            K.RandomPerspective(0.2,p=0.4),
            K.ColorJitter(hue=0.01, saturation=0.01, p=0.7))
        # Upper bound of the per-cutout noise amplitude; 0 disables the noise.
        self.noise_fac = 0.1

    def forward(self, input):
        """Return the augmented cutouts concatenated along dim 0.

        ``input`` is indexed as a 4-D tensor whose dims 2 and 3 are the
        spatial height and width.
        """
        sideY, sideX = input.shape[2:4]
        max_size = min(sideX, sideY)
        min_size = min(sideX, sideY, self.cut_size)
        cutouts = []
        for _ in range(self.cutn):
            # Random side length in [min_size, max_size], then a random
            # top-left corner that keeps the square inside the image.
            size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
            offsetx = torch.randint(0, sideX - size + 1, ())
            offsety = torch.randint(0, sideY - size + 1, ())
            cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
            cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
        batch = self.augs(torch.cat(cutouts, dim=0))
        if self.noise_fac:
            # One noise amplitude per row of the augmented batch. Sizing this
            # from the batch itself (rather than from cutn) keeps the
            # broadcast valid when the input holds more than one image.
            facs = batch.new_empty([batch.shape[0], 1, 1, 1]).uniform_(0, self.noise_fac)
            batch = batch + facs * torch.randn_like(batch)
        return batch
|
91 |
+
|
92 |
+
def resample(input, size, align_corners=True):
    """Resize a 4-D (n, c, h, w) tensor to ``size`` = (dh, dw).

    When an axis is being shrunk, the input is first low-pass filtered along
    that axis with a Lanczos kernel (reflect-padded so the spatial size is
    kept); the final resize is a bicubic ``F.interpolate``.

    Args:
        input: tensor of shape (n, c, h, w).
        size: target (height, width).
        align_corners: forwarded to ``F.interpolate``.

    Returns:
        Tensor of shape (n, c, dh, dw).
    """
    n, c, h, w = input.shape
    dh, dw = size

    # Fold channels into the batch dim so a single-channel kernel filters
    # every channel independently. reshape (not view) also accepts inputs
    # whose n and c dims cannot be merged without a copy, e.g. permuted or
    # channels-last tensors.
    input = input.reshape([n * c, 1, h, w])

    if dh < h:
        kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
        pad_h = (kernel_h.shape[0] - 1) // 2
        input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
        input = F.conv2d(input, kernel_h[None, None, :, None])

    if dw < w:
        kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
        pad_w = (kernel_w.shape[0] - 1) // 2
        input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
        input = F.conv2d(input, kernel_w[None, None, None, :])

    input = input.reshape([n, c, h, w])
    return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
|
112 |
+
|
113 |
+
class ReplaceGrad(torch.autograd.Function):
    """Autograd function: forward value from one input, gradient to the other.

    The output equals ``x_forward``; on the backward pass ``x_forward``
    receives no gradient and ``x_backward`` receives the incoming gradient
    summed down to its own shape.
    """

    @staticmethod
    def forward(ctx, x_forward, x_backward):
        # Only the shape of the gradient target is needed later.
        ctx.shape = x_backward.shape
        return x_forward

    @staticmethod
    def backward(ctx, grad_in):
        grad_for_backward_input = grad_in.sum_to_size(ctx.shape)
        return None, grad_for_backward_input


# Functional alias used by the rest of the file.
replace_grad = ReplaceGrad.apply
|
125 |
+
|
126 |
+
# Load the pretrained CLOOB model.
import sys

# The cloob_training package is imported from a local ./cloob-training
# directory, so that directory has to be on the import path first.
sys.path.append('./cloob-training')
from cloob_training import model_pt, pretrained

config = pretrained.get_config('cloob_laion_400m_vit_b_16_16_epochs')
cloob = model_pt.get_pt_model(config)
checkpoint = pretrained.download_checkpoint(config)
cloob_params = model_pt.get_pt_params(config, checkpoint)
cloob.load_state_dict(cloob_params)
# Inference only: eval mode, no gradients for the model weights, and placed
# on the selected device.
cloob.eval()
cloob.requires_grad_(False)
cloob.to(device)
print('done')
|
137 |
+
|
138 |
+
# Load fastai model
|
139 |
+
|
140 |
+
import gradio as gr
|
141 |
+
from fastai.vision.all import *
|
142 |
+
from os.path import exists
|
143 |
+
import requests
|
144 |
+
|
145 |
+
import os

model_fn = 'quick_224px'
url = 'https://huggingface.co/johnowhitaker/sketchy_unet_rn34/resolve/main/quick_224px'

if not exists(model_fn):
    print('starting download')
    # Stream into a temporary name and rename only after the whole body has
    # been written. Writing straight to model_fn would leave a truncated
    # file behind on an interrupted download, and every later start would
    # then take the 'file exists' branch and fail to load it.
    partial_fn = model_fn + '.part'
    with requests.get(url, stream=True, timeout=60) as r:
        r.raise_for_status()
        with open(partial_fn, 'wb') as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)
    os.replace(partial_fn, model_fn)
    print('done')
else:
    print('file exists')

# NOTE(review): presumably placeholders so the exported learner can resolve
# these names when it is deserialized - confirm against the training code.
def get_x(item):return None
def get_y(item):return None
# NOTE(review): load_learner deserializes a file fetched from the network;
# only point `url` at a source you trust.
sketch_model = load_learner(model_fn)
|
162 |
+
|
163 |
+
# Cutout settings: number of cutouts per step and the size-distribution
# exponent; the cutout side length is the CLOOB image encoder's input size.
cutn = 16
cut_pow = 1
make_cutouts = MakeCutouts(
    cut_size=cloob.config['image_encoder']['image_size'],
    cutn=cutn,
    cut_pow=cut_pow,
)
|
167 |
+
|
168 |
+
def _encode_prompts(texts, weight):
    """Embed each text with the CLOOB text encoder and wrap it in a Prompt."""
    encoded = []
    for text in texts:
        embed = cloob.text_encoder(cloob.tokenize(text).to(device)).float()
        encoded.append(Prompt(embed, weight, float('-inf')).to(device))
    return encoded


def process_im(image_path,
               sketchify_first=True,
               prompt='A watercolor painting of a face',
               lr=0.03,
               n_iter=10
               ):
    """Stylize an image by optimizing an ImStack against CLOOB text prompts.

    Args:
        image_path: path of the input image file.
        sketchify_first: when true, the image is first passed through
            ``sketch_model`` and its prediction is used as the start image;
            otherwise the file is opened and resized directly.
        prompt: text the image is pulled towards.
        lr: Adam learning rate.
        n_iter: number of optimization steps (coerced to int).

    Returns:
        The result of ``ims.to_pil()`` after the final step.
    """
    # May arrive as a float (e.g. from a numeric UI field); range() needs int.
    n_iter = int(n_iter)

    # Fixed hyper-parameters of the run.
    weight_decay = 1e-4
    out_size = 540
    base_size = 8
    n_layers = 5
    scale = 3
    layer_decay = 0.3
    desat = 0.6  # factor applied to every start layer; 1 leaves them as-is

    if sketchify_first:
        pred = sketch_model.predict(image_path)
        # NOTE(review): assumes pred[0] is a channels-first image tensor with
        # values already in the 0-255 range - confirm against the learner.
        np_im = pred[0].permute(1, 2, 0).numpy()
        pil_im = Image.fromarray(np_im.astype(np.uint8))
    else:
        # Force three channels so RGBA / greyscale / palette uploads yield
        # the same layout as the sketch path.
        # NOTE(review): assumes ImStack expects an RGB init image - confirm.
        pil_im = Image.open(image_path).convert('RGB').resize((out_size, out_size))

    # Positive prompt(s) pull the image towards the text; the fixed negative
    # prompts push it away from noisy / over-saturated results.
    p_prompts = _encode_prompts([prompt], 1)
    n_prompts = _encode_prompts(
        ["Random noise", 'saturated rainbow RGB deep dream'], 0.5)

    # The ImageStack - trying a different scale and n_layers
    ims = ImStack(base_size=base_size,
                  scale=scale,
                  n_layers=n_layers,
                  out_size=out_size,
                  decay=layer_decay,
                  init_image=pil_im)

    # Scale the starting layers by `desat`; the scaled tensors are fresh
    # leaves, so gradients have to be re-enabled on them.
    if desat != 1:
        for i in range(n_layers):
            ims.layers[i] = ims.layers[i].detach()*desat
            ims.layers[i].requires_grad = True

    optimizer = optim.Adam(ims.layers, lr=lr, weight_decay=weight_decay)

    for _ in tqdm(range(n_iter)):
        optimizer.zero_grad()

        im = ims()
        batch = cloob.normalize(make_cutouts(im))
        image_embeds = cloob.image_encoder(batch).float()

        # Distinct loop names: the original reused `prompt`, clobbering the
        # function argument of the same name.
        loss = 0
        for positive in p_prompts:
            loss += positive(image_embeds)
        for negative in n_prompts:
            loss -= negative(image_embeds)

        loss.backward()  # Backprop
        optimizer.step()  # Update

    return ims.to_pil()
|
247 |
+
|
248 |
+
# NOTE(review): Checkbox is imported here but the code below only uses the
# gr.inputs.Checkbox spelling.
from gradio.inputs import Checkbox

# UI controls, in the same order as process_im's parameters.
_demo_inputs = [
    gr.inputs.Image(label="Input Image", shape=(512, 512), type="filepath"),
    gr.inputs.Checkbox(label='Sketchify First', default=True),
    gr.inputs.Textbox(default="A charcoal and watercolor sketch of a person", label="Prompt"),
    gr.inputs.Number(default=0.03, label='LR'),
    gr.inputs.Number(default=10, label='num_steps'),
]
_demo_outputs = [gr.outputs.Image(type="pil", label="Model Output")]

iface = gr.Interface(
    fn=process_im,
    inputs=_demo_inputs,
    outputs=_demo_outputs,
    title='Sketchy ImStack + CLOOB',
    description="Stylize an image with ImStack+CLOOB after a Sketchy Unet",
    article="More info on datasciencecastnet.home.blog",
)
iface.launch(enable_queue=True)
|