DEVICE = 'cpu' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd import random import time import replicate import torch import pickle prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] calibrate_prompts = [ "4k photo", 'surrealist art', 'a psychedelic, fractal view', 'a beautiful collage', 'an intricate portrait', 'an impressionist painting', 'abstract art', 'an eldritch image', 'a sketch', 'a city full of darkness and graffiti', 'a black & white photo', 'a brilliant, timeless tarot card of the world', 'a photo of a woman', '', ] embs = [] ys = [] start_time = time.time() glob_idx = 0 def next_image(): global glob_idx glob_idx = glob_idx + 1 with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) image, pooled_embeds = replicate.run( "rynmurdock/zahir:f1b619cf16566f6262e69bed518b40a83c36f2042b5d6c3c748361cd9532abf7", input={"prompt": prompt,} ) embs.append(pooled_embeds) return image[0] else: print('######### Roaming #########') # sample only as many negatives as there are positives indices = range(len(ys)) pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] lower = min(len(pos_indices), len(neg_indices)) neg_indices = random.sample(neg_indices, lower) pos_indices = random.sample(pos_indices, lower) cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices] cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices] feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(np.array(feature_embs).shape, np.array(ys).shape) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(np.array(feature_embs), np.array(cut_ys)) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16) prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt print(prompt) image, im_emb = replicate.run( "rynmurdock/zahir:f1b619cf16566f6262e69bed518b40a83c36f2042b5d6c3c748361cd9532abf7", input={"prompt": prompt, 'im_emb': pickle.dumps(im_emb)} ) embs.append(im_emb) torch.save(lin_class.coef_, f'./{start_time}.pt') return image[0] def start(_): return [ gr.Button(value='Like', interactive=True), gr.Button(value='Neither', interactive=True), gr.Button(value='Dislike', interactive=True), gr.Button(value='Start', interactive=False), next_image() ] def choose(choice): if choice == 'Like': choice = 1 elif choice == 'Neither': _ = embs.pop(-1) return next_image() else: choice = 0 ys.append(choice) return next_image() css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}" with gr.Blocks(css=css) as demo: with gr.Row(): html = gr.HTML('''