rynmurdock committed on
Commit
f360117
•
1 Parent(s): 8e0b547

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +213 -0
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ DEVICE = 'cpu'
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ from sklearn.svm import LinearSVC
8
+ from sklearn import preprocessing
9
+ import pandas as pd
10
+ import kornia
11
+ import torchvision
12
+
13
+ import random
14
+ import time
15
+
16
+ from diffusers import LCMScheduler
17
+ from diffusers.models import ImageProjection
18
+ from patch_sdxl import SDEmb
19
+ import torch
20
+
21
+
22
# Prompts available for random selection: the unique values of the second
# column of the CSV, keeping only genuine strings (presumably this drops the
# NaN floats pandas yields for empty cells -- confirm against the CSV).
prompt_list = [
    p
    for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]


model_id = "stabilityai/stable-diffusion-xl-base-1.0"
lcm_lora_id = "latent-consistency/lcm-lora-sdxl"

# Build the pipeline from the project's SDEmb class (imported from
# patch_sdxl), attach the LCM LoRA weights and swap in the LCM scheduler
# configured from the pipeline's existing scheduler config.
pipe = SDEmb.from_pretrained(model_id, variant="fp16")
pipe.load_lora_weights(lcm_lora_id)
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
pipe.to(device=DEVICE, dtype=torch.float16)

# The IP-Adapter weights are loaded after the pipeline has been moved to DEVICE.
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
35
+
36
+
37
+
38
# Prompts consumed one at a time, front to back, by next_image() before the
# app starts fitting a classifier on the collected ratings.
calibrate_prompts = [
    "4k photo",
    'surrealist art',
    'a psychedelic, fractal view',
    'a beautiful collage',
    'an intricate portrait',
    'an impressionist painting',
    'abstract art',
    'an eldritch image',
    'a sketch',
    'a city full of darkness and graffiti',
    'a black & white photo',
    'a brilliant, timeless tarot card of the world',
    'a photo of a woman',
    '',
]

# embs collects one image embedding per generated image; ys collects the
# matching rating appended by choose() (1 for 'Like', 0 for 'Dislike').
embs = []
ys = []

# Timestamp taken at import; next_image() uses it as the name of the saved
# classifier-direction file.
start_time = time.time()

# Hidden states are requested from the image encoder unless the UNet's
# encoder_hid_proj is a plain ImageProjection.
output_hidden_state = not isinstance(pipe.unet.encoder_hid_proj, ImageProjection)


# Random 224x224 crops covering 30-50% of the image, followed by per-channel
# normalization (the constants look like the CLIP preprocessing statistics --
# TODO confirm); both are used by patch_encode_image().
transform = kornia.augmentation.RandomResizedCrop(size=(224, 224), scale=(.3, .5))
nom = torchvision.transforms.Normalize(
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
)
65
def patch_encode_image(image):
    """Embed an image as the mean embedding of 16 random crops.

    The image is converted to a float16 tensor, repeated 16 times along a new
    leading dimension, divided by 255, passed through the module-level
    ``transform`` (random resized crop) and ``nom`` (normalization), and then
    encoded with ``pipe.encode_image``.

    Args:
        image: Image accepted by ``torchvision.transforms.functional.pil_to_tensor``.

    Returns:
        The first output of ``pipe.encode_image`` averaged over dimension 0,
        with that dimension kept (size 1).
    """
    # pil_to_tensor already returns a tensor; re-wrapping it in torch.tensor()
    # only made a redundant copy and triggered PyTorch's copy-construct warning.
    pixels = torchvision.transforms.functional.pil_to_tensor(image).to(torch.float16)
    batch = pixels.repeat(16, 1, 1, 1).to(DEVICE)
    batch = batch / 255
    patches = nom(transform(batch))
    output, _ = pipe.encode_image(
        patches, DEVICE, 1, output_hidden_state
    )
    return output.mean(0, keepdim=True)
73
+
74
+
75
# Number of images requested so far; next_image() uses its parity to alternate
# between the fixed prompt 'an image' and a random prompt while roaming.
glob_idx = 0


def _fit_preference_direction():
    """Fit a linear SVM on a class-balanced sample of the rated embeddings.

    Reads the module-level ``embs`` and ``ys``. As many positives as negatives
    are sampled (the smaller class size of the two), the features are
    standardized, and a ``LinearSVC`` is fitted.

    Returns:
        The SVM weight vector divided by its norm, as a double tensor with a
        leading dimension of size 1, or ``None`` when either class has no
        ratings yet (there is nothing to separate in that case).
    """
    pos_indices = [i for i, label in enumerate(ys) if label == 1]
    neg_indices = [i for i, label in enumerate(ys) if label == 0]
    lower = min(len(pos_indices), len(neg_indices))
    if lower == 0:
        # Previously this fell through to torch.stack([]) and raised.
        return None

    # sample only as many negatives as there are positives
    neg_indices = random.sample(neg_indices, lower)
    pos_indices = random.sample(pos_indices, lower)
    chosen = neg_indices + pos_indices

    cut_embs = [embs[i] for i in chosen]
    cut_ys = [ys[i] for i in chosen]

    feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs])
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    # Report the shapes actually used for fitting (the labels are cut_ys, not ys).
    print(np.array(feature_embs).shape, np.array(cut_ys).shape)

    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(
        np.array(feature_embs), np.array(cut_ys)
    )
    # Keep the direction in a local tensor instead of overwriting the fitted
    # estimator's coef_ attribute.
    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    return (coef / coef.norm()).unsqueeze(0)


def next_image():
    """Generate the next image to be rated and record its embedding.

    While ``calibrate_prompts`` is non-empty, the front prompt is popped and an
    image is generated with an all-zero ``ip_adapter_emb``. Afterwards
    ("roaming"), a preference direction is fitted from the ratings collected
    so far and passed as ``ip_adapter_emb``; the prompt alternates between
    ``'an image'`` (even ``glob_idx``) and a random entry of ``prompt_list``.
    If the ratings contain only one class, roaming falls back to the all-zero
    embedding used during calibration instead of failing.

    Side effects: increments ``glob_idx``, appends the new image's embedding
    to ``embs``, and, when a direction was fitted, saves it to
    ``./{start_time}.pt``.

    Returns:
        The first image of the pipeline output.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        direction = None
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
        else:
            print('######### Roaming #########')
            direction = _fit_preference_direction()
            rng_prompt = random.choice(prompt_list)
            prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)

        if direction is None:
            # Calibration, or roaming without both classes rated yet.
            im_emb = torch.zeros(1, 1, 1280, device=DEVICE, dtype=torch.float16)
        else:
            im_emb = direction.to(device=DEVICE, dtype=torch.float16)

        image = pipe(
            prompt=prompt,
            ip_adapter_emb=im_emb,
            height=1024,
            width=1024,
            num_inference_steps=8,
            guidance_scale=0,
        ).images

        new_emb, _ = pipe.encode_image(
            image[0], DEVICE, 1, output_hidden_state
        )
        embs.append(new_emb)

        if direction is not None:
            torch.save(direction, f'./{start_time}.pt')
        return image[0]
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
def start(_):
    """Handle the Start click: enable rating, disable Start, show an image.

    Args:
        _: Input supplied by the click event; ignored.

    Returns:
        A list holding interactive 'Like', 'Neither' and 'Dislike' buttons, a
        non-interactive 'Start' button, and the result of ``next_image()``,
        in that order.
    """
    outputs = [
        gr.Button(value=label, interactive=True)
        for label in ('Like', 'Neither', 'Dislike')
    ]
    outputs.append(gr.Button(value='Start', interactive=False))
    outputs.append(next_image())
    return outputs
169
+
170
+
171
def choose(choice):
    """Record the rating for the current image and produce the next one.

    Args:
        choice: Label of the clicked button. 'Like' is stored as 1,
            'Neither' stores nothing, and anything else is stored as 0.

    Returns:
        The result of ``next_image()``.
    """
    if choice == 'Neither':
        # No rating is stored, so drop the newest embedding as well to keep
        # embs and ys the same length.
        embs.pop(-1)
        return next_image()

    ys.append(1 if choice == 'Like' else 0)
    return next_image()
181
+
182
# Fixed 768x768, centered box for elements with id 'output-image'.
css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}"
with gr.Blocks(css=css) as demo:
    with gr.Row():
        # Fixed user-facing markup: well-formed closing tag, a unit on
        # font-size (a bare "32" is not a valid CSS length), and the
        # spelling of "calibrate".
        html = gr.HTML('''<div style='text-align:center; font-size:32px'>You will calibrate for several prompts and then roam.</div>''')
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image',)
    with gr.Row(equal_height=True):
        # Rating buttons start disabled; start() returns interactive ones.
        b3 = gr.Button(value='Dislike', interactive=False,)
        b2 = gr.Button(value='Neither', interactive=False,)
        b1 = gr.Button(value='Like', interactive=False,)
    # Each rating button passes its own value to choose() and the returned
    # image replaces the one on display.
    for rating_button in (b1, b2, b3):
        rating_button.click(
            choose,
            [rating_button],
            [img]
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4],
                 [b1, b2, b3, b4, img,])

demo.launch()  # Share your demo with just 1 extra parameter 🚀