rynmurdock committed on
Commit 67e8481 (0 parents)
Files changed (6)
  1. README.md +3 -0
  2. app.py +504 -0
  3. license +4 -0
  4. lightning_app.py +452 -0
  5. safety_checker_improved.py +45 -0
  6. twitter_prompts.csv +72 -0
README.md ADDED
@@ -0,0 +1,3 @@
1
+ # Blue Tigers
2
+
3
+ Zahir with movement.
app.py ADDED
@@ -0,0 +1,504 @@
1
+
2
+
3
+
4
+ # TODO save & restart from (if it exists) dataframe parquet
5
+ import torch
6
+
7
+ # lol
8
+ DEVICE = 'cuda'
9
+ STEPS = 6
10
+ output_hidden_state = False
11
+ device = "cuda"
12
+ dtype = torch.float16
13
+
14
+ import matplotlib.pyplot as plt
15
+ import matplotlib
16
+
17
+ from sklearn.linear_model import Ridge
18
+ from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
19
+ CompilationConfig)
20
+ config = CompilationConfig.Default()
21
+
22
+ try:
23
+ import triton
24
+ config.enable_triton = True
25
+ except ImportError:
26
+ print('Triton not installed, skip')
27
+ config.enable_cuda_graph = True
28
+ config.enable_jit = True
29
+ config.enable_jit_freeze = True
30
+ config.enable_cnn_optimization = True
31
+ config.preserve_parameters = False
32
+ config.prefer_lowp_gemm = True
33
+
34
+ import imageio
35
+ import gradio as gr
36
+ import numpy as np
37
+ from sklearn.svm import SVC
38
+ from sklearn.inspection import permutation_importance
39
+ from sklearn import preprocessing
40
+ import pandas as pd
41
+ from apscheduler.schedulers.background import BackgroundScheduler
42
+
43
+ import os  # needed for os.path.isfile / os.remove below
+ import random
44
+ import time
45
+ from PIL import Image
46
+ from safety_checker_improved import maybe_nsfw
47
+
48
+
49
+ torch.set_grad_enabled(False)
50
+ torch.backends.cuda.matmul.allow_tf32 = True
51
+ torch.backends.cudnn.allow_tf32 = True
52
+
53
+ prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
54
+
55
+ import spaces
56
+ prompt_list = [p for p in list(set(
57
+ pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
58
+
59
+ start_time = time.time()
60
+
61
+ ####################### Setup Model
62
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
63
+ from transformers import CLIPTextModel
64
+ from huggingface_hub import hf_hub_download
65
+ from safetensors.torch import load_file
66
+ from PIL import Image
67
+ from transformers import CLIPVisionModelWithProjection
68
+ import uuid
69
+ import av
70
+
71
+ def write_video(file_name, images, fps=17):
72
+ print('Saving')
73
+ container = av.open(file_name, mode="w")
74
+
75
+ stream = container.add_stream("h264", rate=fps)
76
+ # stream.options = {'preset': 'faster'}
77
+ stream.thread_count = 0
78
+ stream.width = 512
79
+ stream.height = 512
80
+ stream.pix_fmt = "yuv420p"
81
+
82
+ for img in images:
83
+ img = np.array(img)
84
+ img = np.round(img).astype(np.uint8)
85
+ frame = av.VideoFrame.from_ndarray(img, format="rgb24")
86
+ for packet in stream.encode(frame):
87
+ container.mux(packet)
88
+ # Flush stream
89
+ for packet in stream.encode():
90
+ container.mux(packet)
91
+ # Close the file
92
+ container.close()
93
+ print('Saved')
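Because `stream.width` and `stream.height` are hard-coded to 512, `write_video` expects 512x512 RGB uint8 frames; a quick usage sketch with dummy data:

import numpy as np

frames = [np.full((512, 512, 3), 128, dtype=np.uint8) for _ in range(16)]  # flat grey frames
write_video('/tmp/example.mp4', frames, fps=17)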
94
+
95
+
96
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype).to(DEVICE)
97
+ #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
98
+
99
+ # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
100
+ # vae = compile_unet(vae, config=config)
101
+
102
+ #finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
103
+ #unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
104
+ #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
105
+
106
+
107
+ unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype)
108
+ text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder').to(dtype)
109
+
110
+ adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
111
+ pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
112
+ pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
113
+ pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
114
+ pipe.set_adapters(["lcm-lora"], [.9])
115
+ pipe.fuse_lora()
116
+
117
+ #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
118
+ #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
119
+ #repo = "ByteDance/AnimateDiff-Lightning"
120
+ #ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
121
+
122
+
123
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
124
+ # This IP adapter improves outputs substantially.
125
+ pipe.set_ip_adapter_scale(.8)
126
+ pipe.unet.fuse_qkv_projections()
127
+ #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
128
+
129
+ #pipe = compile(pipe, config=config)
130
+ pipe.to(device=DEVICE)
131
+ #pipe.unet = torch.compile(pipe.unet)
132
+ #pipe.vae = torch.compile(pipe.vae)
133
+
134
+
135
+ im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
136
+ output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
137
+ leave_im_emb, _ = pipe.encode_image(
138
+ output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
139
+ )
140
+ assert len(output.frames[0]) == 16
141
+ leave_im_emb = leave_im_emb.detach().to('cpu')  # .to() is not in-place; keep the CPU copy
142
+
143
+
144
+ @spaces.GPU()
145
+ def generate(in_im_embs):
146
+ in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
147
+ #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)
148
+
149
+ output = pipe(prompt='a scene', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
150
+
151
+ im_emb, _ = pipe.encode_image(
152
+ output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
153
+ )
154
+ im_emb = im_emb.detach().to('cpu')
155
+
156
+ nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
157
+
158
+ name = str(uuid.uuid4()).replace("-", "")
159
+ path = f"/tmp/{name}.mp4"
160
+
161
+ if nsfw:
162
+ gr.Warning("NSFW content detected.")
163
+ # TODO could return an automatic dislike or auto dislike on the backend for neither as well; just would need refactoring.
164
+ return None, im_emb
165
+
166
+
167
+ output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))
168
+
169
+ write_video(path, output.frames[0])
170
+ return path, im_emb
171
+
172
+
173
+ #######################
174
+
175
+ # TODO add to state instead of shared across all
176
+ glob_idx = 0
177
+
178
+ # TODO
179
+ # We can keep a df of media paths, embeddings, and user ratings.
180
+ # We can drop by lowest user ratings to keep enough RAM available when we get too many rows.
181
+ # We can continuously update by who is most recently active in the background & server as we go, plucking using "has been seen" and similarity
182
+ # to user embeds
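One way the "drop by lowest user ratings" idea above could look, as a hedged sketch that is not wired into the app; the 400-row cap and the neutral 0.5 score for unrated rows are made-up choices:

def prune_low_rated(df, max_rows=400):
    # keep only the best-rated rows once the table grows past max_rows
    if len(df) <= max_rows:
        return df
    def mean_rating(d):
        # ratings are stored as {user_id: 0 or 1}; the placeholder {' ': ' '} has no int values
        vals = [v for v in d.values() if isinstance(v, int)]
        return sum(vals) / len(vals) if vals else .5
    df = df.reset_index(drop=True)
    scores = df['user:rating'].apply(mean_rating)
    keep = scores.sort_values(ascending=False).index[:max_rows]
    return df.loc[sorted(keep)]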
183
+
184
+ def get_user_emb(embs, ys):
185
+ # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
186
+ if len(list(set(ys))) <= 1:
187
+ embs.append(.01*torch.randn(1280))
188
+ embs.append(.01*torch.randn(1280))
189
+ ys.append(0)
190
+ ys.append(1)
191
+ print('Fixing only one feedback class available.\n')
192
+
193
+ indices = list(range(len(embs)))
194
+ # sample only as many negatives as there are positives
195
+ pos_indices = [i for i in indices if ys[i] == 1]
196
+ neg_indices = [i for i in indices if ys[i] == 0]
197
+ #lower = min(len(pos_indices), len(neg_indices))
198
+ #neg_indices = random.sample(neg_indices, lower)
199
+ #pos_indices = random.sample(pos_indices, lower)
200
+ print(len(neg_indices), len(pos_indices))
201
+
202
+
203
+ # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
204
+ # this ends up adding a rating but losing an embedding, it seems.
205
+ # let's take off a rating if so to continue without indexing errors.
206
+ if len(ys) > len(embs):
207
+ print('ys are longer than embs; popping latest rating')
208
+ ys.pop(-1)
209
+
210
+ feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices] + [leave_im_emb.to('cpu').squeeze()]).to('cpu'))
211
+ #scaler = preprocessing.StandardScaler().fit(feature_embs)
212
+ #feature_embs = scaler.transform(feature_embs)
213
+ chosen_y = np.array([ys[i] for i in indices] + [0])
214
+
215
+ print('Gathering coefficients')
216
+ #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
217
+ lin_class = SVC(max_iter=50000, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
218
+ coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
219
+ coef_ = coef_ / coef_.abs().max() * 3
220
+ print('Gathered')
221
+
222
+ w = 1# if len(embs) % 2 == 0 else 0
223
+ im_emb = w * coef_.to(dtype=dtype)
224
+ return im_emb
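A self-contained toy illustration (not project code) of what `get_user_emb` is doing: fit a linear SVM on liked vs. disliked embeddings and use its scaled weight vector as a preference direction for ranking candidates by cosine similarity:

import numpy as np
import torch
from sklearn.svm import SVC

rng = np.random.default_rng(0)
embs = rng.normal(size=(8, 1280)).astype(np.float32)   # stand-in "image embeddings"
ys = np.array([1, 1, 1, 0, 0, 0, 1, 0])                # 1 = liked, 0 = disliked

clf = SVC(kernel='linear', class_weight='balanced').fit(embs, ys)
direction = torch.tensor(clf.coef_).float()
direction = direction / direction.abs().max() * 3      # same scaling as above

candidates = torch.tensor(rng.normal(size=(5, 1280)).astype(np.float32))
sims = torch.cosine_similarity(candidates, direction)  # rank unseen candidates
print(sims.argsort(descending=True))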
225
+
226
+
227
+ def pluck_img(user_id, user_emb):
228
+ print(user_id, 'user_id')
229
+ not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
230
+ rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
231
+ while len(not_rated_rows) == 0:
232
+ not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
233
+ time.sleep(.01)
234
+ # TODO optimize this lol
235
+ best_sim = -100000
236
+ for i in not_rated_rows.iterrows():
237
+ # TODO sloppy .to but it is 3am.
238
+ sim = torch.cosine_similarity(i[1]['embeddings'].detach().to('cpu'), user_emb.detach().to('cpu'))
239
+ if sim > best_sim:
240
+ best_sim = sim
241
+ best_row = i[1]
242
+ img = best_row['paths']
243
+ return img
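The "TODO optimize this lol" above can be addressed by stacking the candidate embeddings and scoring them in one call; a hedged sketch (editor's illustration) that leaves the original waiting loop aside:

def pluck_img_vectorized(user_id, user_emb):
    not_rated = prevs_df[[row['user:rating'].get(user_id, None) is None for _, row in prevs_df.iterrows()]]
    stacked = torch.cat([e.detach().to('cpu').reshape(1, -1) for e in not_rated['embeddings']])
    sims = torch.cosine_similarity(stacked, user_emb.detach().to('cpu').reshape(1, -1))
    return not_rated.iloc[int(sims.argmax())]['paths']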
244
+
245
+
246
+ def background_next_image():
247
+ global prevs_df
248
+
249
+ # only let it get N (maybe 3) ahead of the user
250
+ not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
251
+ rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
252
+ while len(not_rated_rows) > 8 or len(rated_rows) < 4:
253
+ not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
254
+ rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
255
+ time.sleep(.01)
256
+
257
+ print(rated_rows['latest_user_to_rate'])
258
+ latest_user_id = rated_rows.iloc[-1]['latest_user_to_rate']
259
+ rated_rows = prevs_df[[i[1]['user:rating'].get(latest_user_id, None) is not None for i in prevs_df.iterrows()]]
260
+
261
+ print(latest_user_id)
262
+ embs, ys = pluck_embs_ys(latest_user_id)
263
+
264
+ user_emb = get_user_emb(embs, ys)
265
+ img, embs = generate(user_emb)
266
+ tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
267
+ tmp_df['paths'] = [img]
268
+ tmp_df['embeddings'] = [embs]
269
+ tmp_df['user:rating'] = [{' ': ' '}]
270
+ prevs_df = pd.concat((prevs_df, tmp_df))
271
+ # we can free up storage by deleting the image
272
+ if len(prevs_df) > 50:
273
+ oldest_path = prevs_df.iloc[0]['paths']
274
+ if os.path.isfile(oldest_path):
275
+ os.remove(oldest_path)
276
+ else:
277
+ # If it fails, inform the user.
278
+ print("Error: %s file not found" % oldest_path)
279
+ # only keep 50 images & embeddings & ips, then remove oldest
280
+ prevs_df = prevs_df.iloc[1:]
281
+
282
+
283
+ def pluck_embs_ys(user_id):
284
+ rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
285
+ not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
286
+ while len(not_rated_rows) == 0:
287
+ not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) == None for i in prevs_df.iterrows()]]
288
+ rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) != None for i in prevs_df.iterrows()]]
289
+ time.sleep(.01)
290
+
291
+ embs = rated_rows['embeddings'].to_list()
292
+ ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
293
+ print('embs', 'ys', embs, ys)
294
+ return embs, ys
295
+
296
+ def next_image(calibrate_prompts, user_id):
297
+ global glob_idx
298
+ glob_idx = glob_idx + 1
299
+
300
+ with torch.no_grad():
301
+ if len(calibrate_prompts) > 0:
302
+ print('######### Calibrating with sample media #########')
303
+ cal_video = calibrate_prompts.pop(0)
304
+ image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]
305
+
306
+ return image, calibrate_prompts
307
+ else:
308
+ print('######### Roaming #########')
309
+ embs, ys = pluck_embs_ys(user_id)
310
+ user_emb = get_user_emb(embs, ys)
311
+ image = pluck_img(user_id, user_emb)
312
+ return image, calibrate_prompts
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+
321
+
322
+ def start(_, calibrate_prompts, user_id, request: gr.Request):
323
+ image, calibrate_prompts = next_image(calibrate_prompts, user_id)
324
+ return [
325
+ gr.Button(value='Like (L)', interactive=True),
326
+ gr.Button(value='Neither (Space)', interactive=True),
327
+ gr.Button(value='Dislike (A)', interactive=True),
328
+ gr.Button(value='Start', interactive=False),
329
+ image,
330
+ calibrate_prompts
331
+ ]
332
+
333
+
334
+ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
335
+ global prevs_df
336
+
337
+
338
+ if choice == 'Like (L)':
339
+ choice = 1
340
+ elif choice == 'Neither (Space)':
341
+ img, calibrate_prompts = next_image(calibrate_prompts, user_id)
342
+ return img, calibrate_prompts
343
+ else:
344
+ choice = 0
345
+
346
+ # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
347
+ # TODO skip allowing rating & just continue
348
+ if img == None:
349
+ print('NSFW -- choice is disliked')
350
+ choice = 0
351
+
352
+ # TODO clean up
353
+ old_d = prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'user:rating'][0]
354
+ old_d[user_id] = choice
355
+ prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'user:rating'][0] = old_d
356
+ prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'latest_user_to_rate'] = [user_id]
357
+ print('full_df, prevs_df', prevs_df, prevs_df['latest_user_to_rate'])
358
+
359
+ img, calibrate_prompts = next_image(calibrate_prompts, user_id)
360
+ return img, calibrate_prompts
361
+
362
+ css = '''.gradio-container{max-width: 700px !important}
363
+ #description{text-align: center}
364
+ #description h1, #description h3{display: block}
365
+ #description p{margin-top: 0}
366
+ .fade-in-out {animation: fadeInOut 3s forwards}
367
+ @keyframes fadeInOut {
368
+ 0% {
369
+ background: var(--bg-color);
370
+ }
371
+ 100% {
372
+ background: var(--button-secondary-background-fill);
373
+ }
374
+ }
375
+ '''
376
+ js_head = '''
377
+ <script>
378
+ document.addEventListener('keydown', function(event) {
379
+ if (event.key === 'a' || event.key === 'A') {
380
+ // Trigger click on 'dislike' if 'A' is pressed
381
+ document.getElementById('dislike').click();
382
+ } else if (event.key === ' ' || event.keyCode === 32) {
383
+ // Trigger click on 'neither' if Spacebar is pressed
384
+ document.getElementById('neither').click();
385
+ } else if (event.key === 'l' || event.key === 'L') {
386
+ // Trigger click on 'like' if 'L' is pressed
387
+ document.getElementById('like').click();
388
+ }
389
+ });
390
+ function fadeInOut(button, color) {
391
+ button.style.setProperty('--bg-color', color);
392
+ button.classList.remove('fade-in-out');
393
+ void button.offsetWidth; // This line forces a repaint by accessing a DOM property
394
+
395
+ button.classList.add('fade-in-out');
396
+ button.addEventListener('animationend', () => {
397
+ button.classList.remove('fade-in-out'); // Reset the animation state
398
+ }, {once: true});
399
+ }
400
+ document.body.addEventListener('click', function(event) {
401
+ const target = event.target;
402
+ if (target.id === 'dislike') {
403
+ fadeInOut(target, '#ff1717');
404
+ } else if (target.id === 'like') {
405
+ fadeInOut(target, '#006500');
406
+ } else if (target.id === 'neither') {
407
+ fadeInOut(target, '#cccccc');
408
+ }
409
+ });
410
+
411
+ </script>
412
+ '''
413
+
414
+ with gr.Blocks(css=css, head=js_head) as demo:
415
+ gr.Markdown('''# Blue Tigers
416
+ ### Generative Recommenders for Exploration of Video
417
+
418
+ Explore the latent space based on your preferences, without text prompts. Learn more in [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
419
+ ''', elem_id="description")
420
+ user_id = gr.State(int(torch.randint(2**6, (1,))[0]))
421
+ calibrate_prompts = gr.State([
422
+ './first.mp4',
423
+ './second.mp4',
424
+ './third.mp4',
425
+ './fourth.mp4',
426
+ './fifth.mp4',
427
+ './sixth.mp4',
428
+ './seventh.mp4',
429
+ ])
430
+ def l():
431
+ return None
432
+
433
+ with gr.Row(elem_id='output-image'):
434
+ img = gr.Video(
435
+ label='Lightning',
436
+ autoplay=True,
437
+ interactive=False,
438
+ height=512,
439
+ width=512,
440
+ include_audio=False,
441
+ elem_id="video_output"
442
+ )
443
+ img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
444
+ with gr.Row(equal_height=True):
445
+ b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
446
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
447
+ b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
448
+ b1.click(
449
+ choose,
450
+ [img, b1, calibrate_prompts, user_id],
451
+ [img, calibrate_prompts],
452
+ )
453
+ b2.click(
454
+ choose,
455
+ [img, b2, calibrate_prompts, user_id],
456
+ [img, calibrate_prompts],
457
+ )
458
+ b3.click(
459
+ choose,
460
+ [img, b3, calibrate_prompts, user_id],
461
+ [img, calibrate_prompts],
462
+ )
463
+ with gr.Row():
464
+ b4 = gr.Button(value='Start')
465
+ b4.click(start,
466
+ [b4, calibrate_prompts, user_id],
467
+ [b1, b2, b3, b4, img, calibrate_prompts]
468
+ )
469
+ with gr.Row():
470
+ html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam.</div><br><br><br>
471
+ <div style='text-align:center; font-size:14px'>Note that while the AnimateLCM model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
472
+ </div>
473
+ <br><br>
474
+ <div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
475
+ </div>''')
476
+
477
+ scheduler = BackgroundScheduler()
478
+ scheduler.add_job(func=background_next_image, trigger="interval", seconds=1)
479
+ scheduler.start()
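Not in the commit, but a common companion to this APScheduler setup is shutting the scheduler down on exit so the one-second interval job doesn't outlive the app; a small hedged addition:

import atexit
atexit.register(lambda: scheduler.shutdown(wait=False))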
480
+
481
+ # prep our calibration prompts
482
+ for im in [
483
+ './first.mp4',
484
+ './second.mp4',
485
+ './third.mp4',
486
+ './fourth.mp4',
487
+ './fifth.mp4',
488
+ './sixth.mp4',
489
+ './seventh.mp4',
490
+ ]:
491
+ tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
492
+ tmp_df['paths'] = [im]
493
+ image = list(imageio.imiter(im))
494
+ image = image[len(image)//2]
495
+ im_emb, _ = pipe.encode_image(
496
+ image, DEVICE, 1, output_hidden_state
497
+ )
498
+
499
+ tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
500
+ tmp_df['user:rating'] = [{' ': ' '}]
501
+ prevs_df = pd.concat((prevs_df, tmp_df))
502
+
503
+
504
+ demo.launch(share=True)
license ADDED
@@ -0,0 +1,4 @@
1
+ You may use this as you please iff you:
2
+ do not hold the authors liable for any issues you may encounter;
3
+ provide attribution by prominently linking to https://rynmurdock.github.io/posts/2024/3/generative_recomenders/ if you redistribute this code or use it within a product;
4
+ include the word "Tiger" within the name of any products downstream of this.
lightning_app.py ADDED
@@ -0,0 +1,452 @@
1
+
2
+ import torch
3
+
4
+ # lol
5
+ sidel = 512
6
+ DEVICE = 'cuda'
7
+ STEPS = 4
8
+ output_hidden_state = False
9
+ device = "cuda"
10
+ dtype = torch.float16
11
+
12
+ import matplotlib.pyplot as plt
13
+ import matplotlib
14
+ matplotlib.use('TkAgg')
15
+
16
+ from sklearn.linear_model import LinearRegression
17
+ from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
18
+ CompilationConfig)
19
+ config = CompilationConfig.Default()
20
+
21
+ try:
22
+ import triton
23
+ config.enable_triton = True
24
+ except ImportError:
25
+ print('Triton not installed, skip')
26
+ config.enable_cuda_graph = True
27
+
28
+ config.enable_jit = True
29
+ config.enable_jit_freeze = True
30
+
31
+ config.enable_cnn_optimization = True
32
+ config.preserve_parameters = False
33
+ config.prefer_lowp_gemm = True
34
+
35
+ import imageio
36
+ import gradio as gr
37
+ import numpy as np
38
+ from sklearn.svm import SVC
39
+ from sklearn.inspection import permutation_importance
40
+ from sklearn import preprocessing
41
+ import pandas as pd
42
+
43
+ import random
44
+ import time
45
+ from PIL import Image
46
+ from safety_checker_improved import maybe_nsfw
47
+
48
+
49
+ torch.set_grad_enabled(False)
50
+ torch.backends.cuda.matmul.allow_tf32 = True
51
+ torch.backends.cudnn.allow_tf32 = True
52
+
53
+ # TODO put back?
54
+ # import spaces
55
+
56
+ prompt_list = [p for p in list(set(
57
+ pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
58
+
59
+ start_time = time.time()
60
+
61
+ ####################### Setup Model
62
+ from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, ConsistencyDecoderVAE, AutoencoderTiny
63
+ from hyper_tile import split_attention, flush
64
+ from huggingface_hub import hf_hub_download
65
+ from safetensors.torch import load_file
66
+ from PIL import Image
67
+ from transformers import CLIPVisionModelWithProjection
68
+ import uuid
69
+ import av
70
+
71
+ def write_video(file_name, images, fps=10):
72
+ print('Saving')
73
+ container = av.open(file_name, mode="w")
74
+
75
+ stream = container.add_stream("h264", rate=fps)
76
+ stream.width = sidel
77
+ stream.height = sidel
78
+ stream.pix_fmt = "yuv420p"
79
+
80
+ for img in images:
81
+ img = np.array(img)
82
+ img = np.round(img).astype(np.uint8)
83
+ frame = av.VideoFrame.from_ndarray(img, format="rgb24")
84
+ for packet in stream.encode(frame):
85
+ container.mux(packet)
86
+ # Flush stream
87
+ for packet in stream.encode():
88
+ container.mux(packet)
89
+ # Close the file
90
+ container.close()
91
+ print('Saved')
92
+
93
+ bases = {
94
+ #"basem": "emilianJR/epiCRealism"
95
+ #SG161222/Realistic_Vision_V6.0_B1_noVAE
96
+ #runwayml/stable-diffusion-v1-5
97
+ #frankjoshua/realisticVisionV51_v51VAE
98
+ #Lykon/dreamshaper-7
99
+ }
100
+
101
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype).to(DEVICE)
102
+ vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
103
+
104
+ # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
105
+ # vae = compile_unet(vae, config=config)
106
+
107
+ #adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
108
+ #pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype)
109
+ #pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
110
+ #pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
111
+ #pipe.set_adapters(["lcm-lora"], [1])
112
+ #pipe.fuse_lora()
113
+
114
+ pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder, vae=vae)
115
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
116
+ repo = "ByteDance/AnimateDiff-Lightning"
117
+ ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
118
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device='cpu'), strict=False)
119
+
120
+
121
+ pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", map_location='cpu')
122
+ pipe.set_ip_adapter_scale(.8)
123
+ # pipe.unet.fuse_qkv_projections()
124
+ #pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
125
+
126
+ pipe = compile(pipe, config=config)
127
+ pipe.to(device=DEVICE)
128
+
129
+
130
+ # THIS WOULD NEED PATCHING TODO
131
+ with split_attention(pipe.vae, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
132
+ # ! Change the tile_size and disable to see their effects
133
+ with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
134
+ im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
135
+ output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
136
+ leave_im_emb, _ = pipe.encode_image(
137
+ output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
138
+ )
139
+ assert len(output.frames[0]) == 16
140
+ leave_im_emb = leave_im_emb.to('cpu')  # .to() is not in-place; keep the CPU copy
141
+
142
+
143
+ # TODO put back
144
+ # @spaces.GPU()
145
+ def generate(prompt, in_im_embs=None, base='basem'):
146
+
147
+ if in_im_embs == None:
148
+ in_im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
149
+ #in_im_embs = in_im_embs / torch.norm(in_im_embs)
150
+ else:
151
+ in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
152
+ #im_embs = torch.cat((torch.zeros(1, 1024, device=DEVICE, dtype=dtype), in_im_embs), 0)
153
+
154
+ with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
155
+ # ! Change the tile_size and disable to see their effects
156
+ with split_attention(pipe.vae, tile_size=128, disable=False, aspect_ratio=1):
157
+ output = pipe(prompt=prompt, guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
158
+
159
+ im_emb, _ = pipe.encode_image(
160
+ output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
161
+ )
162
+
163
+ nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])
164
+
165
+ name = str(uuid.uuid4()).replace("-", "")
166
+ path = f"/tmp/{name}.mp4"
167
+
168
+ if nsfw:
169
+ gr.Warning("NSFW content detected.")
170
+ # TODO could return an automatic dislike or auto dislike on the backend for neither as well; just would need refactoring.
171
+ return None, im_emb
172
+
173
+ plt.close('all')
174
+ plt.hist(np.array(im_emb.to('cpu')).flatten(), bins=5)
175
+ plt.savefig('real_im_emb_plot.jpg')
176
+
177
+ write_video(path, output.frames[0])
178
+ return path, im_emb.to('cpu')
179
+
180
+
181
+ #######################
182
+
183
+ # TODO add to state instead of shared across all
184
+ glob_idx = 0
185
+
186
+ def next_image(embs, ys, calibrate_prompts):
187
+ global glob_idx
188
+ glob_idx = glob_idx + 1
189
+
190
+ with torch.no_grad():
191
+ if len(calibrate_prompts) > 0:
192
+ print('######### Calibrating with sample prompts #########')
193
+ prompt = calibrate_prompts.pop(0)
194
+ print(prompt)
195
+ image, img_embs = generate(prompt)
196
+ embs += img_embs
197
+ print(len(embs))
198
+ return image, embs, ys, calibrate_prompts
199
+ else:
200
+ print('######### Roaming #########')
201
+
202
+ # sample a .8 of rated embeddings for some stochasticity, or at least two embeddings.
203
+ # could take a sample < len(embs)
204
+ #n_to_choose = max(int((len(embs))), 2)
205
+ #indices = random.sample(range(len(embs)), n_to_choose)
206
+
207
+ # sample only as many negatives as there are positives
208
+ #pos_indices = [i for i in indices if ys[i] == 1]
209
+ #neg_indices = [i for i in indices if ys[i] == 0]
210
+ #lower = min(len(pos_indices), len(neg_indices))
211
+ #neg_indices = random.sample(neg_indices, lower)
212
+ #pos_indices = random.sample(pos_indices, lower)
213
+ #indices = neg_indices + pos_indices
214
+
215
+ pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
216
+ neg_indices = [i for i in range(len(embs)) if ys[i] == 0]
217
+
218
+ # the embs & ys stay tied by index but we shuffle to drop randomly
219
+ random.shuffle(pos_indices)
220
+ random.shuffle(neg_indices)
221
+
222
+ #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
223
+ # pos_indices = pos_indices[32:]
224
+ if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 120/16:
225
+ pos_indices = pos_indices[1:]
226
+ if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 200/16:
227
+ neg_indices = neg_indices[2:]
228
+
229
+
230
+ print(len(pos_indices), len(neg_indices))
231
+ indices = pos_indices + neg_indices
232
+
233
+ embs = [embs[i] for i in indices]
234
+ ys = [ys[i] for i in indices]
235
+ indices = list(range(len(embs)))
236
+
237
+
238
+ # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
239
+ if len(list(set(ys))) <= 1:
240
+ embs.append(.01*torch.randn(1024))
241
+ embs.append(.01*torch.randn(1024))
242
+ ys.append(0)
243
+ ys.append(1)
244
+
245
+
246
+ # also add the latest 0 and the latest 1
247
+ has_0 = False
248
+ has_1 = False
249
+ for i in reversed(range(len(ys))):
250
+ if ys[i] == 0 and has_0 == False:
251
+ indices.append(i)
252
+ has_0 = True
253
+ elif ys[i] == 1 and has_1 == False:
254
+ indices.append(i)
255
+ has_1 = True
256
+ if has_0 and has_1:
257
+ break
258
+
259
+ # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
260
+ # this ends up adding a rating but losing an embedding, it seems.
261
+ # let's take off a rating if so to continue without indexing errors.
262
+ if len(ys) > len(embs):
263
+ print('ys are longer than embs; popping latest rating')
264
+ ys.pop(-1)
265
+
266
+ feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices] + [leave_im_emb[0].to('cpu')]).to('cpu'))
267
+ scaler = preprocessing.StandardScaler().fit(feature_embs)
268
+ feature_embs = scaler.transform(feature_embs)
269
+ chosen_y = np.array([ys[i] for i in indices] + [0])
270
+
271
+ print('Gathering coefficients')
272
+ #lin_class = LinearRegression(fit_intercept=False).fit(feature_embs, chosen_y)
273
+ lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=1).fit(feature_embs, chosen_y)
274
+ coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
275
+ coef_ = coef_ / coef_.abs().max() * 3
276
+ print(coef_.shape, 'COEF')
277
+
278
+ plt.close('all')
279
+ plt.hist(np.array(coef_).flatten(), bins=5)
280
+ plt.savefig('plot.jpg')
281
+ print(coef_)
282
+ print('Gathered')
283
+
284
+ rng_prompt = random.choice(prompt_list)
285
+ w = 1# if len(embs) % 2 == 0 else 0
286
+ im_emb = w * coef_.to(dtype=dtype)
287
+
288
+ prompt= 'the scene' if glob_idx % 2 == 0 else rng_prompt
289
+ print(prompt)
290
+ image, im_emb = generate(prompt, im_emb)
291
+ embs += im_emb
292
+
293
+ if len(embs) > 700/16:
294
+ embs = embs[1:]
295
+ ys = ys[1:]
296
+
297
+ return image, embs, ys, calibrate_prompts
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+ def start(_, embs, ys, calibrate_prompts):
308
+ image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
309
+ return [
310
+ gr.Button(value='Like (L)', interactive=True),
311
+ gr.Button(value='Neither (Space)', interactive=True),
312
+ gr.Button(value='Dislike (A)', interactive=True),
313
+ gr.Button(value='Start', interactive=False),
314
+ image,
315
+ embs,
316
+ ys,
317
+ calibrate_prompts
318
+ ]
319
+
320
+
321
+ def choose(img, choice, embs, ys, calibrate_prompts):
322
+ if choice == 'Like (L)':
323
+ choice = 1
324
+ elif choice == 'Neither (Space)':
325
+ embs = embs[:-1]
326
+ img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
327
+ return img, embs, ys, calibrate_prompts
328
+ else:
329
+ choice = 0
330
+
331
+ # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
332
+ # TODO skip allowing rating
333
+ if img == None:
334
+ print('NSFW -- choice is disliked')
335
+ choice = 0
336
+
337
+ ys += [choice]*1
338
+ img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
339
+ return img, embs, ys, calibrate_prompts
340
+
341
+ css = '''.gradio-container{max-width: 700px !important}
342
+ #description{text-align: center}
343
+ #description h1, #description h3{display: block}
344
+ #description p{margin-top: 0}
345
+ .fade-in-out {animation: fadeInOut 3s forwards}
346
+ @keyframes fadeInOut {
347
+ 0% {
348
+ background: var(--bg-color);
349
+ }
350
+ 100% {
351
+ background: var(--button-secondary-background-fill);
352
+ }
353
+ }
354
+ '''
355
+ js_head = '''
356
+ <script>
357
+ document.addEventListener('keydown', function(event) {
358
+ if (event.key === 'a' || event.key === 'A') {
359
+ // Trigger click on 'dislike' if 'A' is pressed
360
+ document.getElementById('dislike').click();
361
+ } else if (event.key === ' ' || event.keyCode === 32) {
362
+ // Trigger click on 'neither' if Spacebar is pressed
363
+ document.getElementById('neither').click();
364
+ } else if (event.key === 'l' || event.key === 'L') {
365
+ // Trigger click on 'like' if 'L' is pressed
366
+ document.getElementById('like').click();
367
+ }
368
+ });
369
+ function fadeInOut(button, color) {
370
+ button.style.setProperty('--bg-color', color);
371
+ button.classList.remove('fade-in-out');
372
+ void button.offsetWidth; // This line forces a repaint by accessing a DOM property
373
+
374
+ button.classList.add('fade-in-out');
375
+ button.addEventListener('animationend', () => {
376
+ button.classList.remove('fade-in-out'); // Reset the animation state
377
+ }, {once: true});
378
+ }
379
+ document.body.addEventListener('click', function(event) {
380
+ const target = event.target;
381
+ if (target.id === 'dislike') {
382
+ fadeInOut(target, '#ff1717');
383
+ } else if (target.id === 'like') {
384
+ fadeInOut(target, '#006500');
385
+ } else if (target.id === 'neither') {
386
+ fadeInOut(target, '#cccccc');
387
+ }
388
+ });
389
+
390
+ </script>
391
+ '''
392
+
393
+ with gr.Blocks(css=css, head=js_head) as demo:
394
+ gr.Markdown('''### Blue Tigers: Generative Recommenders for Exploration of Video
395
+ Explore the latent space based on your preferences, without text prompts. Learn more in [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
396
+ ''', elem_id="description")
397
+ embs = gr.State([])
398
+ ys = gr.State([])
399
+ calibrate_prompts = gr.State([
400
+ 'the moon is melting into my glass of tea',
401
+ 'a sea slug -- pair of claws scuttling -- jelly fish glowing',
402
+ 'an adorable creature. It may be a goblin or a pig or a slug.',
403
+ 'an animation about a gorgeous nebula',
404
+ 'an octopus writhes',
405
+ ])
406
+ def l():
407
+ return None
408
+
409
+ with gr.Row(elem_id='output-image'):
410
+ img = gr.Video(
411
+ label='Lightning',
412
+ autoplay=True,
413
+ interactive=False,
414
+ height=sidel,
415
+ width=sidel,
416
+ include_audio=False,
417
+ elem_id="video_output"
418
+ )
419
+ img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
420
+ with gr.Row(equal_height=True):
421
+ b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
422
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
423
+ b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
424
+ b1.click(
425
+ choose,
426
+ [img, b1, embs, ys, calibrate_prompts],
427
+ [img, embs, ys, calibrate_prompts]
428
+ )
429
+ b2.click(
430
+ choose,
431
+ [img, b2, embs, ys, calibrate_prompts],
432
+ [img, embs, ys, calibrate_prompts]
433
+ )
434
+ b3.click(
435
+ choose,
436
+ [img, b3, embs, ys, calibrate_prompts],
437
+ [img, embs, ys, calibrate_prompts]
438
+ )
439
+ with gr.Row():
440
+ b4 = gr.Button(value='Start')
441
+ b4.click(start,
442
+ [b4, embs, ys, calibrate_prompts],
443
+ [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
444
+ with gr.Row():
445
+ html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam.</div><br><br><br>
446
+ <div style='text-align:center; font-size:14px'>Note that while the AnimateDiff-Lightning model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
447
+ </div>
448
+ <br><br>
449
+ <div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
450
+ </div>''')
451
+
452
+ demo.launch(share=True)
safety_checker_improved.py ADDED
@@ -0,0 +1,45 @@
1
+
2
+ # TODO required tensorflow==2.14 for me
3
+ # weights from https://github.com/LAION-AI/safety-pipeline/tree/main
4
+ from PIL import Image
5
+ import tensorflow_hub as hub
6
+ import tensorflow
7
+ import numpy as np
8
+ import sys
9
+ sys.path.append('/home/ryn_mote/Misc/generative_recommender/gradio_video/automl/efficientnetv2/')
10
+ import tensorflow as tf
11
+ from tensorflow.keras import mixed_precision
12
+ physical_devices = tf.config.list_physical_devices('GPU')
13
+
14
+ tf.config.experimental.set_memory_growth(
15
+ physical_devices[0], True
16
+ )
17
+
18
+ model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5',custom_objects={"KerasLayer":hub.KerasLayer})
19
+ # "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134), "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).
20
+ # ... we created a manually inspected test set that consists of 4900 samples, that contains images & their captions."
21
+
22
+ # Run prediction
23
+ def maybe_nsfw(pil_image):
24
+ # Run prediction
25
+ imm = tensorflow.image.resize(np.array(pil_image)[:, :, :3], (260, 260))
26
+ imm = (imm / 255)
27
+ pred = model(tensorflow.expand_dims(imm, 0)).numpy()
28
+ probs = tensorflow.math.softmax(pred[0]).numpy()
29
+ print(probs)
30
+ if all([i < .3 for i in probs[[1, 3, 4]]]):
31
+ return False
32
+ return True
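The indices `[1, 3, 4]` above select the NSFW classes, assuming the model's output order follows the class listing quoted earlier (Drawing, Hentai, Neutral, Porn, Sexy), which the code implies but does not state; spelled out:

CLASS_NAMES = ['drawing', 'hentai', 'neutral', 'porn', 'sexy']  # assumed output order
NSFW_CLASSES = [CLASS_NAMES.index(c) for c in ('hentai', 'porn', 'sexy')]  # -> [1, 3, 4]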
33
+
34
+ # pre-initializing prediction
35
+ maybe_nsfw(Image.new("RGB", (260, 260), 255))
36
+ model.load_weights('nsfweffnetv2-b02-3epochs.h5', by_name=True, )
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
twitter_prompts.csv ADDED
@@ -0,0 +1,72 @@
1
+ ,0
2
+ 0,a sunset
3
+ 1,a still life in blue
4
+ 2,last day on earth
5
+ 3,the conch shell
6
+ 4,the winds of change
7
+ 5,a surrealist eye
8
+ 6,a surrealist polaroid photo of an apple
9
+ 7,metaphysics
10
+ 8,the sun is setting into my glass of tea
11
+ 9,the moon at 3am
12
+ 10,a memento mori
13
+ 11,quaking aspen tree
14
+ 12,violets and daffodils
15
+ 13,espresso
16
+ 14,sisyphus
17
+ 15,high windows of stained glass
18
+ 16,a green dog
19
+ 17,an adorable companion; it is a pig
20
+ 18,bird of paradise
21
+ 19,a complex intricate machine
22
+ 20,a white clock
23
+ 21,a film featuring the landscape Salt Lake City Utah
24
+ 22,a creature
25
+ 23,a house set aflame.
26
+ 24,a gorgeous landscape by Cy Twombly
27
+ 25,smoke rises from the caterpillar's hookah
28
+ 26,corvid in red
29
+ 27,Monet's pond
30
+ 28,Genesis
31
+ 29,Death is a black camel that kneels down so we can ride
32
+ 30,a cherry tree made of fractals
33
+ 29,the end of the sidewalk
34
+ 30,a polaroid photo of a bustling city of lights and sky scrapers
35
+ 31,The Fig Tree metaphor
36
+ 32,God killed Van Gogh.
37
+ 33,a cosmic entity alien with four eyes.
38
+ 34,a horse with 128 eyes.
39
+ 35,a being with an infinite set of eyes (it is omniscient)
40
+ 36,A sticky-note magnum opus featuring birds
41
+ 37,Moka Pot
42
+ 38,the moon is a sickle cell
43
+ 39,The Penultimate Supper
44
+ 40,Art
45
+ 41,surrealism
46
+ 42,a god made of wires & dust
47
+ 43,a dandelion blown into the universe
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+