rynmurdock multimodalart HF staff commited on
Commit
db551d5
1 Parent(s): 178e606

Performance PR (#2)

Browse files

- Performance PR (f33c43f609f59f8722b5928f0535007a9157da38)
- Disable SC (e6d1b5454f215a7280081510188907d11646de37)


Co-authored-by: Apolinário from multimodal AI art <multimodalart@users.noreply.huggingface.co>

Files changed (2) hide show
  1. app.py +46 -22
  2. patch_sdxl.py +4 -30
app.py CHANGED
@@ -6,7 +6,7 @@ from sklearn.svm import LinearSVC
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
- from diffusers import LCMScheduler
10
  from diffusers.models import ImageProjection
11
  from patch_sdxl import SDEmb
12
  import torch
@@ -22,6 +22,9 @@ from PIL import Image
22
  import requests
23
  from io import BytesIO, StringIO
24
 
 
 
 
25
  prompt_list = [p for p in list(set(
26
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
27
 
@@ -29,12 +32,16 @@ start_time = time.time()
29
 
30
  ####################### Setup Model
31
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
32
- lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
33
- pipe = SDEmb.from_pretrained(model_id, variant="fp16", low_cpu_mem_usage=True, device_map="auto")
34
- pipe.load_lora_weights(lcm_lora_id)
35
- pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
36
- pipe.to(device='cuda', dtype=torch.float16)
 
 
 
37
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 
38
  output_hidden_state = False
39
  #######################
40
 
@@ -53,7 +60,7 @@ def predict(
53
  ip_adapter_emb=im_emb.to('cuda'),
54
  height=1024,
55
  width=1024,
56
- num_inference_steps=8,
57
  guidance_scale=0,
58
  ).images[0]
59
  im_emb, _ = pipe.encode_image(
@@ -61,12 +68,6 @@ def predict(
61
  )
62
  return image, im_emb.to(DEVICE)
63
 
64
-
65
-
66
-
67
-
68
-
69
-
70
  # TODO add to state instead of shared across all
71
  glob_idx = 0
72
 
@@ -133,9 +134,9 @@ def next_image(embs, ys, calibrate_prompts):
133
  def start(_, embs, ys, calibrate_prompts):
134
  image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
135
  return [
136
- gr.Button(value='Like', interactive=True),
137
- gr.Button(value='Neither', interactive=True),
138
- gr.Button(value='Dislike', interactive=True),
139
  gr.Button(value='Start', interactive=False),
140
  image,
141
  embs,
@@ -157,9 +158,32 @@ def choose(choice, embs, ys, calibrate_prompts):
157
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
158
  return img, embs, ys, calibrate_prompts
159
 
160
- css = ".gradio-container{max-width: 700px !important}"
161
- print(css)
162
- with gr.Blocks(css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  embs = gr.State([])
164
  ys = gr.State([])
165
  calibrate_prompts = gr.State([
@@ -177,9 +201,9 @@ with gr.Blocks(css=css) as demo:
177
  with gr.Row(elem_id='output-image'):
178
  img = gr.Image(interactive=False, elem_id='output-image',width=700)
179
  with gr.Row(equal_height=True):
180
- b3 = gr.Button(value='Dislike', interactive=False,)
181
- b2 = gr.Button(value='Neither', interactive=False,)
182
- b1 = gr.Button(value='Like', interactive=False,)
183
  b1.click(
184
  choose,
185
  [b1, embs, ys, calibrate_prompts],
 
6
  from sklearn import preprocessing
7
  import pandas as pd
8
 
9
+ from diffusers import LCMScheduler, AutoencoderTiny, EulerDiscreteScheduler, UNet2DConditionModel
10
  from diffusers.models import ImageProjection
11
  from patch_sdxl import SDEmb
12
  import torch
 
22
  import requests
23
  from io import BytesIO, StringIO
24
 
25
+ from huggingface_hub import hf_hub_download
26
+ from safetensors.torch import load_file
27
+
28
  prompt_list = [p for p in list(set(
29
  pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
30
 
 
32
 
33
  ####################### Setup Model
34
  model_id = "stabilityai/stable-diffusion-xl-base-1.0"
35
+ sdxl_lightening = "ByteDance/SDXL-Lightning"
36
+ ckpt = "sdxl_lightning_2step_unet.safetensors"
37
+ unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
38
+ unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
39
+ pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
41
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
42
+ pipe.to(device='cuda')
43
  pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
44
+
45
  output_hidden_state = False
46
  #######################
47
 
 
60
  ip_adapter_emb=im_emb.to('cuda'),
61
  height=1024,
62
  width=1024,
63
+ num_inference_steps=2,
64
  guidance_scale=0,
65
  ).images[0]
66
  im_emb, _ = pipe.encode_image(
 
68
  )
69
  return image, im_emb.to(DEVICE)
70
 
 
 
 
 
 
 
71
  # TODO add to state instead of shared across all
72
  glob_idx = 0
73
 
 
134
  def start(_, embs, ys, calibrate_prompts):
135
  image, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
136
  return [
137
+ gr.Button(value='Like (L)', interactive=True),
138
+ gr.Button(value='Neither (Space)', interactive=True),
139
+ gr.Button(value='Dislike (A)', interactive=True),
140
  gr.Button(value='Start', interactive=False),
141
  image,
142
  embs,
 
158
  img, embs, ys, calibrate_prompts = next_image(embs, ys, calibrate_prompts)
159
  return img, embs, ys, calibrate_prompts
160
 
161
+ css = '''.gradio-container{max-width: 700px !important}
162
+ #description{text-align: center}
163
+ #description h1{display: block}
164
+ #description p{margin-top: 0}
165
+ '''
166
+ js = '''
167
+ <script>
168
+ document.addEventListener('keydown', function(event) {
169
+ if (event.key === 'a' || event.key === 'A') {
170
+ // Trigger click on 'dislike' if 'A' is pressed
171
+ document.getElementById('dislike').click();
172
+ } else if (event.key === ' ' || event.keyCode === 32) {
173
+ // Trigger click on 'neither' if Spacebar is pressed
174
+ document.getElementById('neither').click();
175
+ } else if (event.key === 'l' || event.key === 'L') {
176
+ // Trigger click on 'like' if 'L' is pressed
177
+ document.getElementById('like').click();
178
+ }
179
+ });
180
+ </script>
181
+ '''
182
+
183
+ with gr.Blocks(css=css, head=js) as demo:
184
+ gr.Markdown('''# Generative Recommenders
185
+ Explore the latent space without text prompts, based on your preferences. [Learn more on the blog](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/)
186
+ ''', elem_id="description")
187
  embs = gr.State([])
188
  ys = gr.State([])
189
  calibrate_prompts = gr.State([
 
201
  with gr.Row(elem_id='output-image'):
202
  img = gr.Image(interactive=False, elem_id='output-image',width=700)
203
  with gr.Row(equal_height=True):
204
+ b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
205
+ b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
206
+ b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
207
  b1.click(
208
  choose,
209
  [b1, embs, ys, calibrate_prompts],
patch_sdxl.py CHANGED
@@ -1,6 +1,3 @@
1
-
2
-
3
-
4
  import inspect
5
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
6
 
@@ -29,7 +26,6 @@ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOut
29
 
30
 
31
 
32
- from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
33
  from transformers import CLIPFeatureExtractor
34
  import numpy as np
35
  import torch
@@ -40,27 +36,6 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
  torch_device = device
41
  torch_dtype = torch.float16
42
 
43
- safety_checker = StableDiffusionSafetyChecker.from_pretrained(
44
- "CompVis/stable-diffusion-safety-checker"
45
- ).to(device)
46
- feature_extractor = CLIPFeatureExtractor.from_pretrained(
47
- "openai/clip-vit-base-patch32"
48
- )
49
-
50
- def check_nsfw_images(
51
- images: list[Image.Image],
52
- ) -> list[bool]:
53
- safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
54
- images_np = [np.array(img) for img in images]
55
-
56
- _, has_nsfw_concepts = safety_checker(
57
- images=images_np,
58
- clip_input=safety_checker_input.pixel_values.to(torch_device),
59
- )
60
- return has_nsfw_concepts
61
-
62
-
63
-
64
 
65
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
66
 
@@ -569,12 +544,11 @@ class SDEmb(StableDiffusionXLPipeline):
569
  # apply watermark if available
570
  if self.watermark is not None:
571
  image = self.watermark.apply_watermark(image)
572
-
573
  image = self.image_processor.postprocess(image, output_type=output_type)
574
- maybe_nsfw = any(check_nsfw_images(image))
575
- if maybe_nsfw:
576
- print('This image could be NSFW so we return a blank image.')
577
- return StableDiffusionXLPipelineOutput(images=[Image.new('RGB', (1024, 1024))])
578
 
579
  # Offload all models
580
  self.maybe_free_model_hooks()
 
 
 
 
1
  import inspect
2
  from typing import Any, Callable, Dict, List, Optional, Union, Tuple
3
 
 
26
 
27
 
28
 
 
29
  from transformers import CLIPFeatureExtractor
30
  import numpy as np
31
  import torch
 
36
  torch_device = device
37
  torch_dtype = torch.float16
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
 
 
544
  # apply watermark if available
545
  if self.watermark is not None:
546
  image = self.watermark.apply_watermark(image)
 
547
  image = self.image_processor.postprocess(image, output_type=output_type)
548
+ #maybe_nsfw = any(check_nsfw_images(image))
549
+ #if maybe_nsfw:
550
+ # print('This image could be NSFW so we return a blank image.')
551
+ # return StableDiffusionXLPipelineOutput(images=[Image.new('RGB', (1024, 1024))])
552
 
553
  # Offload all models
554
  self.maybe_free_model_hooks()