surokpro2 commited on
Commit
13c384f
1 Parent(s): c8a2456

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -10,6 +10,7 @@ import matplotlib.pyplot as plt
10
  from matplotlib.colors import ListedColormap
11
  from utils import add_feature_on_area, replace_with_feature
12
  import threading
 
13
 
14
  code_to_block = {
15
  "down.2.1": "unet.down_blocks.2.attentions.1",
@@ -19,6 +20,7 @@ code_to_block = {
19
  }
20
  lock = threading.Lock()
21
 
 
22
  def process_cache(cache, saes_dict):
23
 
24
  top_features_dict = {}
@@ -65,6 +67,7 @@ def plot_image_heatmap(cache, block_select, radio):
65
 
66
 
67
  def create_prompt_part(pipe, saes_dict, demo):
 
68
  def image_gen(prompt):
69
  lock.acquire()
70
  try:
@@ -137,6 +140,7 @@ def downsample_mask(image, factor):
137
  return downsampled
138
 
139
  def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
 
140
  def image_gen(prompt, num_steps):
141
  lock.acquire()
142
  try:
@@ -151,6 +155,7 @@ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, mean
151
  lock.release()
152
  return images.images[0]
153
 
 
154
  def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image):
155
  block = block_str.split(" ")[0]
156
 
@@ -184,6 +189,7 @@ def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, mean
184
  lock.release()
185
  return image
186
 
 
187
  def feature_icon(block_str, brush_index):
188
  block = block_str.split(" ")[0]
189
  if block in ["mid.0", "up.0.0"]:
 
10
  from matplotlib.colors import ListedColormap
11
  from utils import add_feature_on_area, replace_with_feature
12
  import threading
13
+ import spaces
14
 
15
  code_to_block = {
16
  "down.2.1": "unet.down_blocks.2.attentions.1",
 
20
  }
21
  lock = threading.Lock()
22
 
23
+ @spaces.GPU
24
  def process_cache(cache, saes_dict):
25
 
26
  top_features_dict = {}
 
67
 
68
 
69
  def create_prompt_part(pipe, saes_dict, demo):
70
+ @spaces.GPU
71
  def image_gen(prompt):
72
  lock.acquire()
73
  try:
 
140
  return downsampled
141
 
142
  def create_intervene_part(pipe: HookedStableDiffusionXLPipeline, saes_dict, means_dict, demo):
143
+ @spaces.GPU
144
  def image_gen(prompt, num_steps):
145
  lock.acquire()
146
  try:
 
155
  lock.release()
156
  return images.images[0]
157
 
158
+ @spaces.GPU
159
  def image_mod(prompt, block_str, brush_index, strength, num_steps, input_image):
160
  block = block_str.split(" ")[0]
161
 
 
189
  lock.release()
190
  return image
191
 
192
+ @spaces.GPU
193
  def feature_icon(block_str, brush_index):
194
  block = block_str.split(" ")[0]
195
  if block in ["mid.0", "up.0.0"]: