DEVICE = 'cpu' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd import random import time import replicate import torch from urllib.request import urlopen from PIL import Image import requests from io import BytesIO, StringIO prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] calibrate_prompts = [ "4k photo", 'surrealist art', # 'a psychedelic, fractal view', 'a beautiful collage', 'abstract art', 'an eldritch image', 'a sketch', # 'a city full of darkness and graffiti', '', ] start_time = time.time() # TODO add to state instead of shared across all glob_idx = 0 def next_image(embs, ys): global glob_idx glob_idx = glob_idx + 1 with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) output = replicate.run( "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42", input={"prompt": prompt,} ) response = requests.get(output['file1']) image = Image.open(BytesIO(response.content)) embs.append(torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0)) return image, embs, ys else: print('######### Roaming #########') # sample only as many negatives as there are positives indices = range(len(ys)) pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] lower = min(len(pos_indices), len(neg_indices)) neg_indices = random.sample(neg_indices, lower) pos_indices = random.sample(pos_indices, lower) cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices] cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices] feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(np.array(feature_embs).shape, np.array(ys).shape) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(np.array(feature_embs), np.array(cut_ys)) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16) prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt print(prompt) im_emb_st = str(im_emb[0].cpu().detach().tolist())[1:-1] output = replicate.run( "rynmurdock/zahir:42c58addd49ab57f1e309f0b9a0f271f483bbef0470758757c623648fe989e42", input={"prompt": prompt, 'im_emb': im_emb_st} ) response = requests.get(output['file1']) image = Image.open(BytesIO(response.content)) im_emb = torch.tensor([float(i) for i in urlopen(output['file2']).read().decode('utf-8').split(', ')]).unsqueeze(0) embs.append(im_emb) torch.save(lin_class.coef_, f'./{start_time}.pt') return image, embs, ys def start(_): return [ gr.Button(value='Like', interactive=True), gr.Button(value='Neither', interactive=True), gr.Button(value='Dislike', interactive=True), gr.Button(value='Start', interactive=False), next_image() ] def choose(choice, embs, ys): if choice == 'Like': choice = 1 elif choice == 'Neither': _ = embs.pop(-1) return next_image(embs, ys) else: choice = 0 ys.append(choice) return next_image(embs, ys) css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}" with gr.Blocks(css=css) as demo: embs = gr.State([]) ys = gr.State([]) with gr.Row(elem_id='output-image'): img = gr.Image(interactive=False, elem_id='output-image',) with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike', interactive=False,) b2 = gr.Button(value='Neither', interactive=False,) b1 = gr.Button(value='Like', interactive=False,) b1.click( choose, [b1, embs, ys], [img, embs, ys] ) b2.click( choose, [b2, embs, ys], [img, embs, ys] ) b3.click( choose, [b3, embs, ys], [img, embs, ys] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, ys], [b1, b2, b3, b4, img,]) with gr.Row(): html = gr.HTML('''