Spaces: Sleeping
rynmurdock committed
Commit • 67e8481
Parent(s): sig
Browse files:
- README.md +3 -0
- app.py +504 -0
- license +4 -0
- lightning_app.py +452 -0
- safety_checker_improved.py +45 -0
- twitter_prompts.csv +72 -0
README.md
ADDED
@@ -0,0 +1,3 @@
# Blue Tigers

Zahir with movement.
app.py
ADDED
@@ -0,0 +1,504 @@
# TODO save & restart from (if it exists) dataframe parquet
import torch

# lol
DEVICE = 'cuda'
STEPS = 6
output_hidden_state = False
device = "cuda"
dtype = torch.float16

import matplotlib.pyplot as plt
import matplotlib

from sklearn.linear_model import Ridge
from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
                                                         CompilationConfig)
config = CompilationConfig.Default()

try:
    import triton
    config.enable_triton = True
except ImportError:
    print('Triton not installed, skip')
config.enable_cuda_graph = True
config.enable_jit = True
config.enable_jit_freeze = True
config.enable_cnn_optimization = True
config.preserve_parameters = False
config.prefer_lowp_gemm = True

import imageio
import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd
from apscheduler.schedulers.background import BackgroundScheduler

import os  # needed below for os.path.isfile / os.remove; missing from the original imports
import random
import time
from PIL import Image
from safety_checker_improved import maybe_nsfw


torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])

import spaces
prompt_list = [p for p in list(set(
    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]

start_time = time.time()

####################### Setup Model
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, AutoencoderTiny, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
from transformers import CLIPVisionModelWithProjection
import uuid
import av

def write_video(file_name, images, fps=17):
    print('Saving')
    container = av.open(file_name, mode="w")

    stream = container.add_stream("h264", rate=fps)
    # stream.options = {'preset': 'faster'}
    stream.thread_count = 0
    stream.width = 512
    stream.height = 512
    stream.pix_fmt = "yuv420p"

    for img in images:
        img = np.array(img)
        img = np.round(img).astype(np.uint8)
        frame = av.VideoFrame.from_ndarray(img, format="rgb24")
        for packet in stream.encode(frame):
            container.mux(packet)
    # Flush stream
    for packet in stream.encode():
        container.mux(packet)
    # Close the file
    container.close()
    print('Saved')
94 |
+
|
95 |
+
|
96 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype).to(DEVICE)
|
97 |
+
#vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
|
98 |
+
|
99 |
+
# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
|
100 |
+
# vae = compile_unet(vae, config=config)
|
101 |
+
|
102 |
+
#finetune_path = '''/home/ryn_mote/Misc/finetune-sd1.5/dreambooth-model best'''''
|
103 |
+
#unet = UNet2DConditionModel.from_pretrained(finetune_path+'/unet/').to(dtype)
|
104 |
+
#text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
|
105 |
+
|
106 |
+
|
107 |
+
unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype)
|
108 |
+
text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder').to(dtype)
|
109 |
+
|
110 |
+
adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
|
111 |
+
pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
|
112 |
+
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
113 |
+
pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
|
114 |
+
pipe.set_adapters(["lcm-lora"], [.9])
|
115 |
+
pipe.fuse_lora()
|
116 |
+
|
117 |
+
#pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
|
118 |
+
#pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
|
119 |
+
#repo = "ByteDance/AnimateDiff-Lightning"
|
120 |
+
#ckpt = f"animatediff_lightning_4step_diffusers.safetensors"
|
121 |
+
|
122 |
+
|
123 |
+
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15_vit-G.bin", map_location='cpu')
|
124 |
+
# This IP adapter improves outputs substantially.
|
125 |
+
pipe.set_ip_adapter_scale(.8)
|
126 |
+
pipe.unet.fuse_qkv_projections()
|
127 |
+
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)
|
128 |
+
|
129 |
+
#pipe = compile(pipe, config=config)
|
130 |
+
pipe.to(device=DEVICE)
|
131 |
+
#pipe.unet = torch.compile(pipe.unet)
|
132 |
+
#pipe.vae = torch.compile(pipe.vae)
|
133 |
+
|
134 |
+
|
135 |
+
im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
|
136 |
+
output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
|
137 |
+
leave_im_emb, _ = pipe.encode_image(
|
138 |
+
output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
|
139 |
+
)
|
140 |
+
assert len(output.frames[0]) == 16
|
141 |
+
leave_im_emb.detach().to('cpu')
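# The warm-up call above runs the pipeline once before any requests arrive and
# keeps the middle frame's embedding as leave_im_emb: get_user_emb() later appends
# it with a 0 label so the learned preference direction steers away from this
# generic, zero-conditioned region of the embedding space.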

@spaces.GPU()
def generate(in_im_embs):
    in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
    #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)

    output = pipe(prompt='a scene', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)

    im_emb, _ = pipe.encode_image(
        output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
    )
    im_emb = im_emb.detach().to('cpu')

    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])

    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"

    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO: could return an automatic dislike on the backend for 'Neither' as well; just needs refactoring.
        return None, im_emb

    # palindrome the frames so the saved video loops smoothly
    output.frames[0] = output.frames[0] + list(reversed(output.frames[0]))

    write_video(path, output.frames[0])
    return path, im_emb
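# generate() is the embedding-to-media decoder: it conditions AnimateLCM purely on
# the given IP-Adapter image embedding (guidance_scale=0, generic text prompt),
# screens the middle frame with the safety checker, and returns the mp4 path along
# with the clip's own embedding so the clip can be recommended and rated later.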

#######################

# TODO add to state instead of shared across all
glob_idx = 0

# TODO
# We can keep a df of media paths, embeddings, and user ratings.
# We can drop the lowest-rated rows to keep enough RAM available when we get too many rows.
# We can continuously update in the background for whoever is most recently active, serving
# as we go, plucking using "has been seen" and similarity to user embeds.

def get_user_emb(embs, ys):
    # handle the case where every rating of the calibration videos is 'Neither', 'Like', or 'Dislike'
    if len(list(set(ys))) <= 1:
        embs.append(.01*torch.randn(1280))
        embs.append(.01*torch.randn(1280))
        ys.append(0)
        ys.append(1)
        print('Fixing only one feedback class available.\n')

    indices = list(range(len(embs)))
    # sample only as many negatives as there are positives
    pos_indices = [i for i in indices if ys[i] == 1]
    neg_indices = [i for i in indices if ys[i] == 0]
    #lower = min(len(pos_indices), len(neg_indices))
    #neg_indices = random.sample(neg_indices, lower)
    #pos_indices = random.sample(pos_indices, lower)
    print(len(neg_indices), len(pos_indices))

    # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
    # this ends up adding a rating but losing an embedding, it seems.
    # let's take off a rating if so to continue without indexing errors.
    if len(ys) > len(embs):
        print('ys are longer than embs; popping latest rating')
        ys.pop(-1)

    feature_embs = np.array(torch.stack([embs[i].squeeze().to('cpu') for i in indices] + [leave_im_emb.to('cpu').squeeze()]).to('cpu'))
    #scaler = preprocessing.StandardScaler().fit(feature_embs)
    #feature_embs = scaler.transform(feature_embs)
    chosen_y = np.array([ys[i] for i in indices] + [0])

    print('Gathering coefficients')
    #lin_class = Ridge(fit_intercept=False).fit(feature_embs, chosen_y)
    lin_class = SVC(max_iter=50000, kernel='linear', C=.1, class_weight='balanced').fit(feature_embs, chosen_y)
    coef_ = torch.tensor(lin_class.coef_, dtype=torch.double).detach().to('cpu')
    coef_ = coef_ / coef_.abs().max() * 3
    print('Gathered')

    w = 1  # if len(embs) % 2 == 0 else 0
    im_emb = w * coef_.to(dtype=dtype)
    return im_emb
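# Under the hood, the preference model above is a linear SVM over the IP-Adapter
# image embeddings: liked media form the positive class, disliked media (plus the
# warm-up embedding) the negative, and the rescaled coefficient vector is fed back
# to generate() as a synthetic image embedding pointing toward the user's tastes.
# A rough sketch of the idea, assuming embs is a list of 1280-d tensors and ys
# their 0/1 ratings:
#
#     clf = SVC(kernel='linear').fit(torch.stack(embs).numpy(), ys)
#     direction = torch.tensor(clf.coef_)               # (1, 1280) separating normal
#     direction = direction / direction.abs().max() * 3
#     video_path, _ = generate(direction.to(dtype))     # decode preferences into media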

def pluck_img(user_id, user_emb):
    print(user_id, 'user_id')
    not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is None for i in prevs_df.iterrows()]]
    rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
    while len(not_rated_rows) == 0:
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is None for i in prevs_df.iterrows()]]
        time.sleep(.01)
    # TODO optimize this lol
    best_sim = -100000
    for i in not_rated_rows.iterrows():
        # TODO sloppy .to but it is 3am.
        sim = torch.cosine_similarity(i[1]['embeddings'].detach().to('cpu'), user_emb.detach().to('cpu'))
        if sim > best_sim:
            best_sim = sim
            best_row = i[1]
    img = best_row['paths']
    return img


def background_next_image():
    global prevs_df

    # only let it get N (maybe 3) ahead of the user
    not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
    rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
    while len(not_rated_rows) > 8 or len(rated_rows) < 4:
        not_rated_rows = prevs_df[[i[1]['user:rating'] == {' ': ' '} for i in prevs_df.iterrows()]]
        rated_rows = prevs_df[[i[1]['user:rating'] != {' ': ' '} for i in prevs_df.iterrows()]]
        time.sleep(.01)

    print(rated_rows['latest_user_to_rate'])
    latest_user_id = rated_rows.iloc[-1]['latest_user_to_rate']
    rated_rows = prevs_df[[i[1]['user:rating'].get(latest_user_id, None) is not None for i in prevs_df.iterrows()]]

    print(latest_user_id)
    embs, ys = pluck_embs_ys(latest_user_id)

    user_emb = get_user_emb(embs, ys)
    img, embs = generate(user_emb)
    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
    tmp_df['paths'] = [img]
    tmp_df['embeddings'] = [embs]
    tmp_df['user:rating'] = [{' ': ' '}]
    prevs_df = pd.concat((prevs_df, tmp_df))
    # we can free up storage by deleting the oldest image
    if len(prevs_df) > 50:
        oldest_path = prevs_df.iloc[0]['paths']
        if os.path.isfile(oldest_path):
            os.remove(oldest_path)
        else:
            # If it fails, inform the user.
            print("Error: %s file not found" % oldest_path)
        # only keep 50 images & embeddings & ips; remove the oldest row
        prevs_df = prevs_df.iloc[1:]


def pluck_embs_ys(user_id):
    rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
    not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is None for i in prevs_df.iterrows()]]
    while len(not_rated_rows) == 0:
        not_rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is None for i in prevs_df.iterrows()]]
        rated_rows = prevs_df[[i[1]['user:rating'].get(user_id, None) is not None for i in prevs_df.iterrows()]]
        time.sleep(.01)

    embs = rated_rows['embeddings'].to_list()
    ys = [i[user_id] for i in rated_rows['user:rating'].to_list()]
    print('embs', 'ys', embs, ys)
    return embs, ys

def next_image(calibrate_prompts, user_id):
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample media #########')
            cal_video = calibrate_prompts.pop(0)
            image = prevs_df[prevs_df['paths'] == cal_video]['paths'].to_list()[0]

            return image, calibrate_prompts
        else:
            print('######### Roaming #########')
            embs, ys = pluck_embs_ys(user_id)
            user_emb = get_user_emb(embs, ys)
            image = pluck_img(user_id, user_emb)
            return image, calibrate_prompts
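# Serving works as a producer/consumer loop over the shared prevs_df frame:
# background_next_image() (scheduled every second via APScheduler below) generates
# media for the most recently active user and appends unrated rows, while
# next_image()/pluck_img() hand each user the unrated row whose embedding is most
# cosine-similar to their preference vector. The {' ': ' '} dict is a sentinel
# marking a row that no user has rated yet.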

def start(_, calibrate_prompts, user_id, request: gr.Request):
    image, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        image,
        calibrate_prompts
    ]


def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
    global prevs_df

    if choice == 'Like (L)':
        choice = 1
    elif choice == 'Neither (Space)':
        img, calibrate_prompts = next_image(calibrate_prompts, user_id)
        return img, calibrate_prompts
    else:
        choice = 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO: skip allowing rating & just continue
    if img is None:
        print('NSFW -- choice is disliked')
        choice = 0

    # TODO clean up
    old_d = prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'user:rating'][0]
    old_d[user_id] = choice
    prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'user:rating'][0] = old_d
    prevs_df.loc[[p.split('/')[-1] in img for p in prevs_df['paths'].to_list()], 'latest_user_to_rate'] = [user_id]
    print('full_df, prevs_df', prevs_df, prevs_df['latest_user_to_rate'])

    img, calibrate_prompts = next_image(calibrate_prompts, user_id)
    return img, calibrate_prompts

css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
        background: var(--bg-color);
    }
    100% {
        background: var(--button-secondary-background-fill);
    }
}
'''
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
function fadeInOut(button, color) {
    button.style.setProperty('--bg-color', color);
    button.classList.remove('fade-in-out');
    void button.offsetWidth;  // This line forces a repaint by accessing a DOM property

    button.classList.add('fade-in-out');
    button.addEventListener('animationend', () => {
        button.classList.remove('fade-in-out');  // Reset the animation state
    }, {once: true});
}
document.body.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
        fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
        fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
        fadeInOut(target, '#cccccc');
    }
});
</script>
'''

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Blue Tigers
### Generative Recommenders for Exploration of Video

Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    user_id = gr.State(int(torch.randint(2**6, (1,))[0]))
    calibrate_prompts = gr.State([
        './first.mp4',
        './second.mp4',
        './third.mp4',
        './fourth.mp4',
        './fifth.mp4',
        './sixth.mp4',
        './seventh.mp4',
    ])
    def l():
        return None

    with gr.Row(elem_id='output-image'):
        img = gr.Video(
            label='Lightning',
            autoplay=True,
            interactive=False,
            height=512,
            width=512,
            include_audio=False,
            elem_id="video_output"
        )
        img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        b1.click(
            choose,
            [img, b1, calibrate_prompts, user_id],
            [img, calibrate_prompts],
        )
        b2.click(
            choose,
            [img, b2, calibrate_prompts, user_id],
            [img, calibrate_prompts],
        )
        b3.click(
            choose,
            [img, b3, calibrate_prompts, user_id],
            [img, calibrate_prompts],
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, calibrate_prompts, user_id],
                 [b1, b2, b3, b4, img, calibrate_prompts]
                 )
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several videos and then roam.</div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the AnimateLCM model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface, and @maxbittker for feedback.
</div>''')

scheduler = BackgroundScheduler()
scheduler.add_job(func=background_next_image, trigger="interval", seconds=1)
scheduler.start()

# prep our calibration prompts
for im in [
    './first.mp4',
    './second.mp4',
    './third.mp4',
    './fourth.mp4',
    './fifth.mp4',
    './sixth.mp4',
    './seventh.mp4',
]:
    tmp_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating'])
    tmp_df['paths'] = [im]
    # depending on the installed imageio version, this may need the v3 API (imageio.v3.imiter)
    image = list(imageio.imiter(im))
    image = image[len(image)//2]
    im_emb, _ = pipe.encode_image(
        image, DEVICE, 1, output_hidden_state
    )

    tmp_df['embeddings'] = [im_emb.detach().to('cpu')]
    tmp_df['user:rating'] = [{' ': ' '}]
    prevs_df = pd.concat((prevs_df, tmp_df))


demo.launch(share=True)
license
ADDED
@@ -0,0 +1,4 @@
You may use this as you please iff you:
do not hold the authors liable for any issues you may encounter;
provide attribution by prominently linking to https://rynmurdock.github.io/posts/2024/3/generative_recomenders/ if you redistribute this code or use it within a product;
include the word "Tiger" within the name of any products downstream of this.
lightning_app.py
ADDED
@@ -0,0 +1,452 @@
import torch

# lol
sidel = 512
DEVICE = 'cuda'
STEPS = 4
output_hidden_state = False
device = "cuda"
dtype = torch.float16

import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('TkAgg')

from sklearn.linear_model import LinearRegression
from sfast.compilers.diffusion_pipeline_compiler import (compile, compile_unet,
                                                         CompilationConfig)
config = CompilationConfig.Default()

try:
    import triton
    config.enable_triton = True
except ImportError:
    print('Triton not installed, skip')
config.enable_cuda_graph = True

config.enable_jit = True
config.enable_jit_freeze = True

config.enable_cnn_optimization = True
config.preserve_parameters = False
config.prefer_lowp_gemm = True

import imageio
import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd

import random
import time
from PIL import Image
from safety_checker_improved import maybe_nsfw


torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# TODO put back?
# import spaces

prompt_list = [p for p in list(set(
    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]

start_time = time.time()

####################### Setup Model
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler, LCMScheduler, ConsistencyDecoderVAE, AutoencoderTiny
from hyper_tile import split_attention, flush
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
from transformers import CLIPVisionModelWithProjection
import uuid
import av

def write_video(file_name, images, fps=10):
    print('Saving')
    container = av.open(file_name, mode="w")

    stream = container.add_stream("h264", rate=fps)
    stream.width = sidel
    stream.height = sidel
    stream.pix_fmt = "yuv420p"

    for img in images:
        img = np.array(img)
        img = np.round(img).astype(np.uint8)
        frame = av.VideoFrame.from_ndarray(img, format="rgb24")
        for packet in stream.encode(frame):
            container.mux(packet)
    # Flush stream
    for packet in stream.encode():
        container.mux(packet)
    # Close the file
    container.close()
    print('Saved')

bases = {
    #"basem": "emilianJR/epiCRealism"
    #SG161222/Realistic_Vision_V6.0_B1_noVAE
    #runwayml/stable-diffusion-v1-5
    #frankjoshua/realisticVisionV51_v51VAE
    #Lykon/dreamshaper-7
}

image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=dtype).to(DEVICE)
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)

# vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
# vae = compile_unet(vae, config=config)

#adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
#pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype)
#pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
#pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora",)
#pipe.set_adapters(["lcm-lora"], [1])
#pipe.fuse_lora()

pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder, vae=vae)
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = "animatediff_lightning_4step_diffusers.safetensors"
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device='cpu'), strict=False)


pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin", map_location='cpu')
pipe.set_ip_adapter_scale(.8)
# pipe.unet.fuse_qkv_projections()
#pipe.enable_free_init(method="gaussian", use_fast_sampling=True)

pipe = compile(pipe, config=config)
pipe.to(device=DEVICE)


# THIS WOULD NEED PATCHING TODO
with split_attention(pipe.vae, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
    # ! Change the tile_size and disable to see their effects
    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
        im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
        output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
        leave_im_emb, _ = pipe.encode_image(
            output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
        )
assert len(output.frames[0]) == 16
leave_im_emb.to('cpu')
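# split_attention comes from the hyper_tile package (HyperTile), which chunks
# self-attention in the UNet/VAE into smaller tiles to reduce attention cost at
# this resolution. The surrounding warm-up also triggers stable-fast's compilation
# and keeps the middle frame's embedding as leave_im_emb for the preference model,
# as in app.py.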

# TODO put back
# @spaces.GPU()
def generate(prompt, in_im_embs=None, base='basem'):

    if in_im_embs is None:
        in_im_embs = torch.zeros(1, 1, 1, 1024, device=DEVICE, dtype=dtype)
        #in_im_embs = in_im_embs / torch.norm(in_im_embs)
    else:
        in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
        #im_embs = torch.cat((torch.zeros(1, 1024, device=DEVICE, dtype=dtype), in_im_embs), 0)

    with split_attention(pipe.unet, tile_size=128, swap_size=2, disable=False, aspect_ratio=1):
        # ! Change the tile_size and disable to see their effects
        with split_attention(pipe.vae, tile_size=128, disable=False, aspect_ratio=1):
            output = pipe(prompt=prompt, guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)

    im_emb, _ = pipe.encode_image(
        output.frames[0][len(output.frames[0])//2], DEVICE, 1, output_hidden_state
    )

    nsfw = maybe_nsfw(output.frames[0][len(output.frames[0])//2])

    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"

    if nsfw:
        gr.Warning("NSFW content detected.")
        # TODO: could return an automatic dislike on the backend for 'Neither' as well; just needs refactoring.
        return None, im_emb

    plt.close('all')
    plt.hist(np.array(im_emb.to('cpu')).flatten(), bins=5)
    plt.savefig('real_im_emb_plot.jpg')

    write_video(path, output.frames[0])
    return path, im_emb.to('cpu')

#######################

# TODO add to state instead of shared across all
glob_idx = 0

def next_image(embs, ys, calibrate_prompts):
    global glob_idx
    glob_idx = glob_idx + 1

    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, img_embs = generate(prompt)
            embs += img_embs
            print(len(embs))
            return image, embs, ys, calibrate_prompts
        else:
            print('######### Roaming #########')

            # sample .8 of rated embeddings for some stochasticity, or at least two embeddings.
            # could take a sample < len(embs)
            #n_to_choose = max(int((len(embs))), 2)
            #indices = random.sample(range(len(embs)), n_to_choose)

            # sample only as many negatives as there are positives
            #pos_indices = [i for i in indices if ys[i] == 1]
            #neg_indices = [i for i in indices if ys[i] == 0]
            #lower = min(len(pos_indices), len(neg_indices))
            #neg_indices = random.sample(neg_indices, lower)
            #pos_indices = random.sample(pos_indices, lower)
            #indices = neg_indices + pos_indices

            pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
            neg_indices = [i for i in range(len(embs)) if ys[i] == 0]

            # the embs & ys stay tied by index but we shuffle to drop randomly
            random.shuffle(pos_indices)
            random.shuffle(neg_indices)

            #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
            #    pos_indices = pos_indices[32:]
            if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 120/16:
                pos_indices = pos_indices[1:]
            if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 200/16:
                neg_indices = neg_indices[2:]

            print(len(pos_indices), len(neg_indices))
            indices = pos_indices + neg_indices

            embs = [embs[i] for i in indices]
            ys = [ys[i] for i in indices]
            indices = list(range(len(embs)))

            # handle the case where every rating of the calibration prompts is 'Neither', 'Like', or 'Dislike'
            if len(list(set(ys))) <= 1:
                embs.append(.01*torch.randn(1024))
                embs.append(.01*torch.randn(1024))
                ys.append(0)
                ys.append(1)

            # also add the latest 0 and the latest 1
            has_0 = False
            has_1 = False
            for i in reversed(range(len(ys))):
                if ys[i] == 0 and has_0 == False:
                    indices.append(i)
                    has_0 = True
                elif ys[i] == 1 and has_1 == False:
                    indices.append(i)
                    has_1 = True
                if has_0 and has_1:
                    break

            # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749);
            # this ends up adding a rating but losing an embedding, it seems.
            # let's take off a rating if so to continue without indexing errors.
            if len(ys) > len(embs):
                print('ys are longer than embs; popping latest rating')
                ys.pop(-1)

            feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices] + [leave_im_emb[0].to('cpu')]).to('cpu'))
            scaler = preprocessing.StandardScaler().fit(feature_embs)
            feature_embs = scaler.transform(feature_embs)
            chosen_y = np.array([ys[i] for i in indices] + [0])

            print('Gathering coefficients')
            #lin_class = LinearRegression(fit_intercept=False).fit(feature_embs, chosen_y)
            lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=1).fit(feature_embs, chosen_y)
            coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
            coef_ = coef_ / coef_.abs().max() * 3
            print(coef_.shape, 'COEF')

            plt.close('all')
            plt.hist(np.array(coef_).flatten(), bins=5)
            plt.savefig('plot.jpg')
            print(coef_)
            print('Gathered')

            rng_prompt = random.choice(prompt_list)
            w = 1  # if len(embs) % 2 == 0 else 0
            im_emb = w * coef_.to(dtype=dtype)

            prompt = 'the scene' if glob_idx % 2 == 0 else rng_prompt
            print(prompt)
            image, im_emb = generate(prompt, im_emb)
            embs += im_emb

            if len(embs) > 700/16:
                embs = embs[1:]
                ys = ys[1:]

            return image, embs, ys, calibrate_prompts
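# The index juggling above is a class-balancing heuristic: surplus dislikes (and
# the occasional like) are randomly dropped so the SVM is not swamped by one class,
# the most recent 0 and 1 are always re-appended so fresh feedback stays in the
# fit, and history is capped (len(embs) > 700/16) to keep fitting fast as ratings
# accumulate.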

def start(_, embs, ys, calibrate_prompts):
    image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        image,
        embs,
        ys,
        calibrate_prompts
    ]


def choose(img, choice, embs, ys, calibrate_prompts):
    if choice == 'Like (L)':
        choice = 1
    elif choice == 'Neither (Space)':
        embs = embs[:-1]
        img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
        return img, embs, ys, calibrate_prompts
    else:
        choice = 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO: skip allowing rating
    if img is None:
        print('NSFW -- choice is disliked')
        choice = 0

    ys += [choice]*1
    img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
    return img, embs, ys, calibrate_prompts

css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% {
        background: var(--bg-color);
    }
    100% {
        background: var(--button-secondary-background-fill);
    }
}
'''
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
    if (event.key === 'a' || event.key === 'A') {
        // Trigger click on 'dislike' if 'A' is pressed
        document.getElementById('dislike').click();
    } else if (event.key === ' ' || event.keyCode === 32) {
        // Trigger click on 'neither' if Spacebar is pressed
        document.getElementById('neither').click();
    } else if (event.key === 'l' || event.key === 'L') {
        // Trigger click on 'like' if 'L' is pressed
        document.getElementById('like').click();
    }
});
function fadeInOut(button, color) {
    button.style.setProperty('--bg-color', color);
    button.classList.remove('fade-in-out');
    void button.offsetWidth;  // This line forces a repaint by accessing a DOM property

    button.classList.add('fade-in-out');
    button.addEventListener('animationend', () => {
        button.classList.remove('fade-in-out');  // Reset the animation state
    }, {once: true});
}
document.body.addEventListener('click', function(event) {
    const target = event.target;
    if (target.id === 'dislike') {
        fadeInOut(target, '#ff1717');
    } else if (target.id === 'like') {
        fadeInOut(target, '#006500');
    } else if (target.id === 'neither') {
        fadeInOut(target, '#cccccc');
    }
});
</script>
'''

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''### Blue Tigers: Generative Recommenders for Exploration of Video
Explore the latent space without text prompts based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'an octopus writhes',
    ])
    def l():
        return None

    with gr.Row(elem_id='output-image'):
        img = gr.Video(
            label='Lightning',
            autoplay=True,
            interactive=False,
            height=sidel,
            width=sidel,
            include_audio=False,
            elem_id="video_output"
        )
        img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        b1.click(
            choose,
            [img, b1, embs, ys, calibrate_prompts],
            [img, embs, ys, calibrate_prompts]
        )
        b2.click(
            choose,
            [img, b2, embs, ys, calibrate_prompts],
            [img, embs, ys, calibrate_prompts]
        )
        b3.click(
            choose,
            [img, b3, embs, ys, calibrate_prompts],
            [img, embs, ys, calibrate_prompts]
        )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, img, embs, ys, calibrate_prompts])
    with gr.Row():
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam.</div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the AnimateDiff-Lightning model with NSFW filtering is unlikely to produce NSFW images, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface, and @maxbittker for feedback.
</div>''')

demo.launch(share=True)
safety_checker_improved.py
ADDED
@@ -0,0 +1,45 @@
# TODO required tensorflow==2.14 for me
# weights from https://github.com/LAION-AI/safety-pipeline/tree/main
from PIL import Image
import tensorflow_hub as hub
import tensorflow
import numpy as np
import sys
sys.path.append('/home/ryn_mote/Misc/generative_recommender/gradio_video/automl/efficientnetv2/')
import tensorflow as tf
from tensorflow.keras import mixed_precision
physical_devices = tf.config.list_physical_devices('GPU')

tf.config.experimental.set_memory_growth(
    physical_devices[0], True
)

model = tf.keras.models.load_model('nsfweffnetv2-b02-3epochs.h5', custom_objects={"KerasLayer": hub.KerasLayer})
# "The image classifier had been trained on 682550 images from the 5 classes "Drawing" (39026), "Hentai" (28134),
# "Neutral" (369507), "Porn" (207969) & "Sexy" (37914).
# ... we created a manually inspected test set that consists of 4900 samples, that contains images & their captions."

# Run prediction
def maybe_nsfw(pil_image):
    # Run prediction
    imm = tensorflow.image.resize(np.array(pil_image)[:, :, :3], (260, 260))
    imm = (imm / 255)
    pred = model(tensorflow.expand_dims(imm, 0)).numpy()
    probs = tensorflow.math.softmax(pred[0]).numpy()
    print(probs)
    if all([i < .3 for i in probs[[1, 3, 4]]]):
        return False
    return True
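# Index note (an inference from the quoted training description above): the five
# output classes appear to be ordered Drawing, Hentai, Neutral, Porn, Sexy, so
# probs[[1, 3, 4]] selects the Hentai/Porn/Sexy scores and an image only passes
# when all three stay below the 0.3 threshold.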

# pre-initializing prediction
maybe_nsfw(Image.new("RGB", (260, 260), 255))
model.load_weights('nsfweffnetv2-b02-3epochs.h5', by_name=True)
twitter_prompts.csv
ADDED
@@ -0,0 +1,72 @@
,0
0,a sunset
1,a still life in blue
2,last day on earth
3,the conch shell
4,the winds of change
5,a surrealist eye
6,a surrealist polaroid photo of an apple
7,metaphysics
8,the sun is setting into my glass of tea
9,the moon at 3am
10,a memento mori
11,quaking aspen tree
12,violets and daffodils
13,espresso
14,sisyphus
15,high windows of stained glass
16,a green dog
17,an adorable companion; it is a pig
18,bird of paradise
19,a complex intricate machine
20,a white clock
21,a film featuring the landscape Salt Lake City Utah
22,a creature
23,a house set aflame.
24,a gorgeous landscape by Cy Twombly
25,smoke rises from the caterpillar's hookah
26,corvid in red
27,Monet's pond
28,Genesis
29,Death is a black camel that kneels down so we can ride
30,a cherry tree made of fractals
29,the end of the sidewalk
30,a polaroid photo of a bustling city of lights and sky scrapers
31,The Fig Tree metaphor
32,God killed Van Gogh.
33,a cosmic entity alien with four eyes.
34,a horse with 128 eyes.
35,a being with an infinite set of eyes (it is omniscient)
36,A sticky-note magnum opus featuring birds
37,Moka Pot
38,the moon is a sickle cell
39,The Penultimate Supper
40,Art
41,surrealism
42,a god made of wires & dust
43,a dandelion blown into the universe
(lines 48-72 of the file are blank rows)