Commit f824106 by rynmurdock
Parent: ba0dc8e

    prompts are visible

app.py CHANGED
--- a/app.py
+++ b/app.py
@@ -15,7 +15,6 @@ torch.set_float32_matmul_precision('high')
 import random
 import time
 
-# TODO put back
 import spaces
 from urllib.request import urlopen
 
@@ -54,8 +53,6 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, times
 
 pipe.to(device=DEVICE)
 
-
-# TODO put back
 @spaces.GPU
 def compile_em():
     pipe.unet = torch.compile(pipe.unet, mode='reduce-overhead')
@@ -88,7 +85,7 @@ class BottleneckT5Autoencoder:
             encode_only=True,
         )
 
-    def generate_from_latent(self, latent: torch.FloatTensor, max_length=
+    def generate_from_latent(self, latent: torch.FloatTensor, max_length=20, temperature=1., top_p=.8, min_new_tokens=30) -> str:
         dummy_text = '.'
         dummy = self.embed(dummy_text)
         perturb_vector = latent - dummy
@@ -111,7 +108,6 @@ autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t
 compile_em()
 #######################
 
-# TODO put back
 @spaces.GPU
 def generate(prompt, in_embs=None,):
     if prompt != '':
@@ -121,11 +117,10 @@ def generate(prompt, in_embs=None,):
     else:
         print('From embeds.')
         in_embs = in_embs / in_embs.abs().max() * .15
-    text = autoencoder.generate_from_latent(in_embs.to('cuda').to(dtype=torch.bfloat16), temperature=.
+    text = autoencoder.generate_from_latent(in_embs.to('cuda').to(dtype=torch.bfloat16), temperature=.3, top_p=.99, min_new_tokens=5)
     return text, in_embs.to('cpu')
 
 
-# TODO put back
 @spaces.GPU
 def predict(
     prompt,
@@ -233,7 +228,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
         im_emb = autoencoder.embed(prompt)
         embs.append(im_emb)
         img_embs.append(img_emb)
-        return image, embs, img_embs, ys, calibrate_prompts
+        return image, embs, img_embs, ys, calibrate_prompts, prompt
     else:
         print('######### Roaming #########')
 
@@ -264,7 +259,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
         image, img_emb = predict(prompt, im_emb=img_emb)
         img_embs.append(img_emb)
 
-        return image, embs, img_embs, ys, calibrate_prompts
+        return image, embs, img_embs, ys, calibrate_prompts, prompt
 
 
 
@@ -275,7 +270,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
 
 
 def start(_, embs, img_embs, ys, calibrate_prompts):
-    image, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
+    image, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
     return [
         gr.Button(value='Like (L)', interactive=True),
         gr.Button(value='Neither (Space)', interactive=True),
@@ -285,7 +280,8 @@ def start(_, embs, img_embs, ys, calibrate_prompts):
         embs,
         img_embs,
         ys,
-        calibrate_prompts
+        calibrate_prompts,
+        prompt
     ]
 
 
@@ -295,8 +291,8 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
     elif choice == 'Neither (Space)':
         _ = embs.pop(-1)
         _ = img_embs.pop(-1)
-        img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
-        return img, embs, img_embs, ys, calibrate_prompts
+        img, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
+        return img, embs, img_embs, ys, calibrate_prompts, prompt
     else:
         choice = 0
 
@@ -305,8 +301,8 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
         choice = 0
 
     ys.append(choice)
-    img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
-    return img, embs, img_embs, ys, calibrate_prompts
+    img, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
+    return img, embs, img_embs, ys, calibrate_prompts, prompt
 
 css = '''.gradio-container{max-width: 700px !important}
 #description{text-align: center}
@@ -374,7 +370,8 @@ with gr.Blocks(css=css, head=js_head) as demo:
             'a sketch of an impressive mountain by da vinci',
             'a watercolor painting: the octopus writhes',
             ])
-
+    with gr.Row():
+        prompt = gr.Textbox(interactive=False, elem_id="text")
     with gr.Row(elem_id='output-image'):
         img = gr.Image(interactive=False, elem_id='output-image', width=700)
     with gr.Row(equal_height=True):
@@ -384,23 +381,23 @@ with gr.Blocks(css=css, head=js_head) as demo:
     b1.click(
         choose,
         [img, b1, embs, img_embs, ys, calibrate_prompts],
-        [img, embs, img_embs, ys, calibrate_prompts]
+        [img, embs, img_embs, ys, calibrate_prompts, prompt]
     )
     b2.click(
        choose,
        [img, b2, embs, img_embs, ys, calibrate_prompts],
-       [img, embs, img_embs, ys, calibrate_prompts]
+       [img, embs, img_embs, ys, calibrate_prompts, prompt]
    )
    b3.click(
        choose,
        [img, b3, embs, img_embs, ys, calibrate_prompts],
-       [img, embs, img_embs, ys, calibrate_prompts]
+       [img, embs, img_embs, ys, calibrate_prompts, prompt]
    )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, img_embs, ys, calibrate_prompts],
-                [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts])
+                [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts, prompt])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </ div><br><br><br>
        <div style='text-align:center; font-size:14px'>Note that while the SDXL model is unlikely to produce NSFW images, it still may be possible, and users should avoid NSFW content when rating.