DEVICE = 'cpu' import gradio as gr import numpy as np from sklearn.svm import LinearSVC from sklearn import preprocessing import pandas as pd from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel from diffusers.models import ImageProjection from patch_sdxl import SDEmb import torch import spaces import random import time import torch from urllib.request import urlopen from PIL import Image import requests from io import BytesIO, StringIO from huggingface_hub import hf_hub_download from safetensors.torch import load_file prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() ####################### Setup Model model_id = "stabilityai/stable-diffusion-xl-base-1.0" sdxl_lightening = "ByteDance/SDXL-Lightning" ckpt = "sdxl_lightning_2step_unet.safetensors" unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16) unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda")) pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda") pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16) pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing") pipe.to(device='cuda') pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") output_hidden_state = False ####################### @spaces.GPU def predict( prompt, im_emb=None, progress=gr.Progress(track_tqdm=True) ): """Run a single prediction on the model""" with torch.no_grad(): if im_emb == None: im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda') image = pipe( prompt=prompt, ip_adapter_emb=im_emb.to('cuda'), height=1024, width=1024, num_inference_steps=2, guidance_scale=0, ).images[0] im_emb, _ = pipe.encode_image( image, 'cuda', 1, output_hidden_state ) return image, im_emb.to(DEVICE) # TODO add to state instead of shared across all glob_idx = 0 def next_image(embs, ys, calibrate_prompts): global glob_idx glob_idx = glob_idx + 1 # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike' if len(calibrate_prompts) == 0 and len(list(set(ys))) <= 1: embs.append(.01*torch.randn(1, 1280)) embs.append(.01*torch.randn(1, 1280)) ys.append(0) ys.append(1) with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) image, img_emb = predict(prompt) embs.append(img_emb) return image, embs, ys, calibrate_prompts else: print('######### Roaming #########') # sample only as many negatives as there are positives indices = range(len(ys)) pos_indices = [i for i in indices if ys[i] == 1] neg_indices = [i for i in indices if ys[i] == 0] lower = min(len(pos_indices), len(neg_indices)) neg_indices = random.sample(neg_indices, lower) pos_indices = random.sample(pos_indices, lower) cut_embs = [embs[i] for i in neg_indices] + [embs[i] for i in pos_indices] cut_ys = [ys[i] for i in neg_indices] + [ys[i] for i in pos_indices] feature_embs = torch.stack([e[0].detach().cpu() for e in cut_embs]) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) print(np.array(feature_embs).shape, np.array(ys).shape) lin_class = LinearSVC(max_iter=50000, dual='auto', class_weight='balanced').fit(np.array(feature_embs), np.array(cut_ys)) lin_class.coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) lin_class.coef_ = (lin_class.coef_.flatten() / (lin_class.coef_.flatten().norm())).unsqueeze(0) rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * lin_class.coef_.to(device=DEVICE, dtype=torch.float16) prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt print(prompt) image, im_emb = predict(prompt, im_emb) embs.append(im_emb) return image, embs, ys, calibrate_prompts def start(_, embs, ys, calibrate_prompts): image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return [ gr.Button(value='Like (L)', interactive=True), gr.Button(value='Neither (Space)', interactive=True), gr.Button(value='Dislike (A)', interactive=True), gr.Button(value='Start', interactive=False), image, embs, ys, calibrate_prompts ] def choose(choice, embs, ys, calibrate_prompts): if choice == 'Like': choice = 1 elif choice == 'Neither': _ = embs.pop(-1) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return img, embs, ys, calibrate_prompts else: choice = 0 ys.append(choice) img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts) return img, embs, ys, calibrate_prompts css = '''.gradio-container{max-width: 700px !important} #description{text-align: center} #description h1{display: block} #description p{margin-top: 0} ''' js = ''' ''' with gr.Blocks(css=css, head=js) as demo: gr.Markdown('''# Generative Recommenders Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/) ''', elem_id="description") embs = gr.State([]) ys = gr.State([]) calibrate_prompts = gr.State([ "4k photo", 'surrealist art', # 'a psychedelic, fractal view', 'a beautiful collage', 'abstract art', 'an eldritch image', 'a sketch', # 'a city full of darkness and graffiti', '', ]) with gr.Row(elem_id='output-image'): img = gr.Image(interactive=False, elem_id='output-image',width=700) with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike") b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither") b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like") b1.click( choose, [b1, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b2.click( choose, [b2, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) b3.click( choose, [b3, embs, ys, calibrate_prompts], [img, embs, ys, calibrate_prompts] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, ys, calibrate_prompts], [b1, b2, b3, b4, img, embs, ys, calibrate_prompts]) with gr.Row(): html = gr.HTML('''