# Gradio app: show generated images, collect Like/Dislike ratings, fit a
# linear SVM on the rated image embeddings, and steer further generations
# along the learned preference direction.
DEVICE = 'cpu'

import random
import time
import pickle
from io import BytesIO, StringIO
from urllib.request import urlopen

import gradio as gr
import numpy as np
import pandas as pd
import replicate
import requests
import torch
from PIL import Image
from sklearn import preprocessing
from sklearn.svm import LinearSVC

# Replicate model used for every generation. As consumed below, its output
# maps 'file1' -> URL of the generated image and 'file2' -> URL of a text file
# holding that image's embedding as ', '-separated floats.
MODEL = "rynmurdock/zahir:49ebb1916c4baae35884ebfa16b092cf45d086c1913b53f62bb07d575cdbe683"

# Free-text prompts for roaming: second CSV column, de-duplicated, with
# non-string cells (e.g. NaN for empty rows) dropped.
prompt_list = [
    p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())
    if isinstance(p, str)
]

# Prompts rendered (unguided) first, in order, to gather initial ratings.
calibrate_prompts = [
    "4k photo",
    'surrealist art',
    'a psychedelic, fractal view',
    'a beautiful collage',
    'abstract art',
    'an eldritch image',
    'a sketch',
    'a city full of darkness and graffiti',
    '',
]

# One (1, d) embedding tensor per image shown. The last entry belongs to the
# image currently on screen and has no rating yet.
embs = []
# Ratings for embs[:len(ys)]: 1 = Like, 0 = Dislike ("Neither" drops the emb).
ys = []

start_time = time.time()
glob_idx = 0


def _read_embedding(url):
    """Download a ', '-separated float list from *url* as a (1, d) tensor."""
    # `with` closes the HTTP response instead of leaking it.
    with urlopen(url) as resp:
        text = resp.read().decode('utf-8')
    return torch.tensor([float(v) for v in text.split(', ')]).unsqueeze(0)


def _generate(prompt, im_emb_file=None):
    """Run MODEL on *prompt* (optionally guided by an embedding file).

    Returns:
        (PIL image, (1, d) embedding tensor of that image).
    """
    model_input = {"prompt": prompt}
    if im_emb_file is not None:
        model_input['im_emb'] = im_emb_file
    output = replicate.run(MODEL, input=model_input)
    response = requests.get(output['file1'])
    response.raise_for_status()  # fail with the HTTP error, not a PIL decode error
    image = Image.open(BytesIO(response.content))
    return image, _read_embedding(output['file2'])


def _preference_direction():
    """Fit a linear SVM on class-balanced rated embeddings.

    Returns:
        The SVM's unit-norm weight vector as a (1, d) float64 tensor, or None
        when there is not at least one Like and one Dislike yet.
    """
    pos_indices = [i for i, y in enumerate(ys) if y == 1]
    neg_indices = [i for i, y in enumerate(ys) if y == 0]
    lower = min(len(pos_indices), len(neg_indices))
    if lower == 0:
        # torch.stack([]) and a single-class LinearSVC would both raise.
        return None

    # sample only as many negatives as there are positives (and vice versa)
    neg_indices = random.sample(neg_indices, lower)
    pos_indices = random.sample(pos_indices, lower)
    cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices]
    cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices]

    feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]).numpy()
    feature_embs = preprocessing.StandardScaler().fit_transform(feature_embs)
    print(feature_embs.shape, np.array(cut_ys).shape)

    lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(
        feature_embs, np.array(cut_ys))
    coef = torch.tensor(lin_class.coef_, dtype=torch.double).flatten()
    return (coef / coef.norm()).unsqueeze(0)


def next_image():
    """Generate the next image to show, append its embedding to `embs`, return it.

    While calibration prompts remain, the next one is rendered unguided.
    Afterwards ("roaming"), the SVM preference direction is sent to the model
    as 'im_emb' with a prompt that alternates between 'an image' and a random
    CSV prompt, and the direction is saved to ./<start_time>.pt. If the SVM
    cannot be fit yet, a random CSV prompt is rendered unguided instead.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        if calibrate_prompts:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            image, im_emb = _generate(prompt)
            embs.append(im_emb)
            return image

        print('######### Roaming #########')
        direction = _preference_direction()
        if direction is None:
            # Need >= 1 Like and >= 1 Dislike before steering; explore meanwhile.
            prompt = random.choice(prompt_list)
            print(prompt)
            image, im_emb = _generate(prompt)
            embs.append(im_emb)
            return image

        rng_prompt = random.choice(prompt_list)
        im_emb = direction.to(device=DEVICE, dtype=torch.float16)
        prompt = 'an image' if glob_idx % 2 == 0 else rng_prompt
        print(prompt)
        # Serialize as ', '-separated floats (the same format the model's
        # 'file2' output uses); StringIO(initial) leaves the cursor at 0.
        emb_file = StringIO(str(im_emb[0].cpu().detach().tolist())[1:-1])
        image, new_emb = _generate(prompt, emb_file)
        embs.append(new_emb)
        torch.save(direction, f'./{start_time}.pt')
        return image


def start(_):
    """Enable the rating buttons, disable Start, and show the first image."""
    return [
        gr.Button(value='Like', interactive=True),
        gr.Button(value='Neither', interactive=True),
        gr.Button(value='Dislike', interactive=True),
        gr.Button(value='Start', interactive=False),
        next_image(),
    ]


def choose(choice):
    """Record the rating (button label) for the current image; return the next image."""
    if choice == 'Neither':
        # Drop the unrated embedding so embs stays aligned with ys.
        embs.pop(-1)
        return next_image()
    ys.append(1 if choice == 'Like' else 0)
    return next_image()


css = "div#output-image {height: 768px !important; width: 768px !important; margin:auto;}"

with gr.Blocks(css=css) as demo:
    with gr.Row():
        html = gr.HTML('''
You will calibrate for several prompts and then roam.''')
    with gr.Row(elem_id='output-image'):
        img = gr.Image(interactive=False, elem_id='output-image',)
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike', interactive=False,)
        b2 = gr.Button(value='Neither', interactive=False,)
        b1 = gr.Button(value='Like', interactive=False,)
        b1.click(choose, [b1], [img])
        b2.click(choose, [b2], [img])
        b3.click(choose, [b3], [img])
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start, [b4], [b1, b2, b3, b4, img,])

demo.launch()  # Share your demo with just 1 extra parameter 🚀