DEVICE = 'cpu' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd from diffusers import LCMScheduler from diffusers.models import ImageProjection from patch_sdxl import SDEmb import torch import spaces import random import time import torch from urllib.request import urlopen from PIL import Image import requests from io import BytesIO, StringIO prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() ####################### Setup Model model_id = "stabilityai/stable-diffusion-xl-base-1.0" lcm_lora_id = "latent-consistency/lcm-lora-sdxl" pipe = SDEmb.from_pretrained(model_id, variant="fp16", low_cpu_mem_usage=True, device_map="auto") pipe.load_lora_weights(lcm_lora_id) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) pipe.to(device='cuda', dtype=torch.float16) pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") output_hidden_state = False ####################### @spaces.GPU def predict( prompt, im_emb=None, progress=gr.Progress(track_tqdm=True) ): """Run a single prediction on the model""" with torch.no_grad(): if im_emb == None: im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda') image = pipe( prompt=prompt, ip_adapter_emb=im_emb.to('cuda'), height=1024, width=1024, num_inference_steps=8, guidance_scale=0, ).images[0] im_emb, _ = pipe.encode_image( image, 'cuda', 1, output_hidden_state ) return image, im_emb.to(DEVICE) # TODO add to state instead of shared across all glob_idx = 0 def next_image(embs, ys, calibrate_prompts): global glob_idx glob_idx = glob_idx + 1 # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike' if len(calibrate_prompts) == 0 and len(list(set(ys))) <= 1: embs.append(.01*torch.randn(1, 1280)) embs.append(.01*torch.randn(1, 1280)) ys.append(0) ys.append(1) with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) image, img_emb = predict(prompt) embs.append(img_emb) return image, embs, ys, calibrate_prompts else: print('######### Roaming #########') # sample only as many negatives as there are positives indices = range(len(ys)) pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] lower = min(len(pos_indices), len(neg_indices)) neg_indices = random.sample(neg_indices, lower) pos_indices = random.sample(pos_indices, lower) cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices] cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices] feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(np.array(feature_embs).shape, np.array(ys).shape) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(np.array(feature_embs), np.array(cut_ys)) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16) prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt print(prompt) image, im_emb = predict(prompt, im_emb) embs.append(im_emb) return image, embs, ys, calibrate_prompts def start(_, embs, ys, calibrate_prompts): image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return [ gr.Button(value='Like', interactive=True), gr.Button(value='Neither', interactive=True), gr.Button(value='Dislike', interactive=True), gr.Button(value='Start', interactive=False), image, embs, ys, calibrate_prompts ] def choose(choice, embs, ys, calibrate_prompts): if choice == 'Like': choice = 1 elif choice == 'Neither': _ = embs.pop(-1) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return img, embs, ys, calibrate_prompts else: choice = 0 ys.append(choice) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return img, embs, ys, calibrate_prompts css = ".gradio-container{max-width: 700px !important}" print(css) with gr.Blocks(css=css) as demo: embs = gr.State([]) ys = gr.State([]) calibrate_prompts = gr.State([ "4k photo", 'surrealist art', # 'a psychedelic, fractal view', 'a beautiful collage', 'abstract art', 'an eldritch image', 'a sketch', # 'a city full of darkness and graffiti', '', ]) with gr.Row(elem_id='output-image'): img = gr.Image(interactive=False, elem_id='output-image',width=700) with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike', interactive=False,) b2 = gr.Button(value='Neither', interactive=False,) b1 = gr.Button(value='Like', interactive=False,) b1.click( choose, [b1, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b2.click( choose, [b2, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b3.click( choose, [b3, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, ys, calibrate_prompts], [b1, b2, b3, b4, img, embs, ys, calibrate_prompts]) with gr.Row(): html = gr.HTML('''