rynmurdock committed
Commit f824106
1 Parent(s): ba0dc8e

prompts are visible

Files changed (1)
  1. app.py +17 -20
app.py CHANGED
@@ -15,7 +15,6 @@ torch.set_float32_matmul_precision('high')
 import random
 import time
 
-# TODO put back
 import spaces
 from urllib.request import urlopen
 
@@ -54,8 +53,6 @@ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, times
 
 pipe.to(device=DEVICE)
 
-
-# TODO put back
 @spaces.GPU
 def compile_em():
     pipe.unet = torch.compile(pipe.unet, mode='reduce-overhead')
@@ -88,7 +85,7 @@ class BottleneckT5Autoencoder:
             encode_only=True,
         )
 
-    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, min_new_tokens=30) -> str:
+    def generate_from_latent(self, latent: torch.FloatTensor, max_length=20, temperature=1., top_p=.8, min_new_tokens=30) -> str:
         dummy_text = '.'
         dummy = self.embed(dummy_text)
         perturb_vector = latent - dummy
@@ -111,7 +108,6 @@ autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t
 compile_em()
 #######################
 
-# TODO put back
 @spaces.GPU
 def generate(prompt, in_embs=None,):
     if prompt != '':
@@ -121,11 +117,10 @@ def generate(prompt, in_embs=None,):
     else:
         print('From embeds.')
         in_embs = in_embs / in_embs.abs().max() * .15
-        text = autoencoder.generate_from_latent(in_embs.to('cuda').to(dtype=torch.bfloat16), temperature=.8, top_p=.94, min_new_tokens=5)
+        text = autoencoder.generate_from_latent(in_embs.to('cuda').to(dtype=torch.bfloat16), temperature=.3, top_p=.99, min_new_tokens=5)
     return text, in_embs.to('cpu')
 
 
-# TODO put back
 @spaces.GPU
 def predict(
     prompt,
@@ -233,7 +228,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
         im_emb = autoencoder.embed(prompt)
         embs.append(im_emb)
         img_embs.append(img_emb)
-        return image, embs, img_embs, ys, calibrate_prompts
+        return image, embs, img_embs, ys, calibrate_prompts, prompt
     else:
         print('######### Roaming #########')
 
@@ -264,7 +259,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
         image, img_emb = predict(prompt, im_emb=img_emb)
         img_embs.append(img_emb)
 
-        return image, embs, img_embs, ys, calibrate_prompts
+        return image, embs, img_embs, ys, calibrate_prompts, prompt
 
 
 
@@ -275,7 +270,7 @@ def next_image(embs, img_embs, ys, calibrate_prompts):
 
 
 def start(_, embs, img_embs, ys, calibrate_prompts):
-    image, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
+    image, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
     return [
         gr.Button(value='Like (L)', interactive=True),
         gr.Button(value='Neither (Space)', interactive=True),
@@ -285,7 +280,8 @@ def start(_, embs, img_embs, ys, calibrate_prompts):
         embs,
         img_embs,
         ys,
-        calibrate_prompts
+        calibrate_prompts,
+        prompt
     ]
 
 
@@ -295,8 +291,8 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
     elif choice == 'Neither (Space)':
         _ = embs.pop(-1)
         _ = img_embs.pop(-1)
-        img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
-        return img, embs, img_embs, ys, calibrate_prompts
+        img, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
+        return img, embs, img_embs, ys, calibrate_prompts, prompt
     else:
         choice = 0
 
@@ -305,8 +301,8 @@ def choose(img, choice, embs, img_embs, ys, calibrate_prompts):
         choice = 0
 
     ys.append(choice)
-    img, embs, img_embs, ys, calibrate_prompts = next_image(embs, img_embs, ys, calibrate_prompts)
-    return img, embs, img_embs, ys, calibrate_prompts
+    img, embs, img_embs, ys, calibrate_prompts, prompt = next_image(embs, img_embs, ys, calibrate_prompts)
+    return img, embs, img_embs, ys, calibrate_prompts, prompt
 
 css = '''.gradio-container{max-width: 700px !important}
 #description{text-align: center}
@@ -374,7 +370,8 @@ with gr.Blocks(css=css, head=js_head) as demo:
         'a sketch of an impressive mountain by da vinci',
         'a watercolor painting: the octopus writhes',
     ])
-
+    with gr.Row():
+        prompt = gr.Textbox(interactive=False, elem_id="text")
     with gr.Row(elem_id='output-image'):
         img = gr.Image(interactive=False, elem_id='output-image', width=700)
     with gr.Row(equal_height=True):
@@ -384,23 +381,23 @@ with gr.Blocks(css=css, head=js_head) as demo:
     b1.click(
         choose,
         [img, b1, embs, img_embs, ys, calibrate_prompts],
-        [img, embs, img_embs, ys, calibrate_prompts]
+        [img, embs, img_embs, ys, calibrate_prompts, prompt]
     )
     b2.click(
         choose,
         [img, b2, embs, img_embs, ys, calibrate_prompts],
-        [img, embs, img_embs, ys, calibrate_prompts]
+        [img, embs, img_embs, ys, calibrate_prompts, prompt]
    )
    b3.click(
        choose,
        [img, b3, embs, img_embs, ys, calibrate_prompts],
-       [img, embs, img_embs, ys, calibrate_prompts]
+       [img, embs, img_embs, ys, calibrate_prompts, prompt]
    )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, img_embs, ys, calibrate_prompts],
-                [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts])
+                [b1, b2, b3, b4, img, embs, img_embs, ys, calibrate_prompts, prompt])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </ div><br><br><br>
        <div style='text-align:center; font-size:14px'>Note that while the SDXL model is unlikely to produce NSFW images, it still may be possible, and users should avoid NSFW content when rating.
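
In short, next_image() now returns the prompt text alongside the image and state, every handler forwards it, and a new non-interactive gr.Textbox receives it, which is what makes the prompts visible. Below is a minimal, self-contained sketch of that wiring pattern; next_item, counter, and prompt_box are illustrative stand-ins, not the Space's actual identifiers.

# A minimal sketch of the pattern this commit applies (illustrative names only,
# not the Space's real identifiers): the click handler returns one extra value,
# and a non-interactive gr.Textbox is appended to the outputs list so the prompt
# behind each image becomes visible in the UI.
import gradio as gr

def next_item(counter):
    # Stand-in for next_image(): it would also produce an image; here we return
    # None for the image slot and just surface the prompt text.
    prompt = f'calibration prompt #{counter}'
    return None, counter + 1, prompt

with gr.Blocks() as demo:
    counter = gr.State(0)
    prompt_box = gr.Textbox(interactive=False, label='prompt')  # analogous to the elem_id="text" box
    img = gr.Image(interactive=False)
    btn = gr.Button('Next')
    # The change boils down to this extra output slot receiving the prompt string.
    btn.click(next_item, [counter], [img, counter, prompt_box])

if __name__ == '__main__':
    demo.launch()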