surokpro2 commited on
Commit
3cb415b
1 Parent(s): 79ab360
Files changed (4) hide show
  1. SAE/sae.py +0 -1
  2. SDLens/hooked_sd_pipeline.py +0 -1
  3. app.py +5 -1
  4. utils/hooks.py +0 -5
SAE/sae.py CHANGED
@@ -105,7 +105,6 @@ class SparseAutoencoder(nn.Module):
105
  def n_dirs(self):
106
  return self.n_dirs_local
107
 
108
- @spaces.GPU
109
  def encode(self, x):
110
  x = x.to('cuda') - self.pre_bias
111
  latents_pre_act = self.encoder(x) + self.latent_bias
 
105
  def n_dirs(self):
106
  return self.n_dirs_local
107
 
 
108
  def encode(self, x):
109
  x = x.to('cuda') - self.pre_bias
110
  latents_pre_act = self.encoder(x) + self.latent_bias
SDLens/hooked_sd_pipeline.py CHANGED
@@ -309,7 +309,6 @@ class HookedDiffusionAbstractPipeline:
309
  def __setattr__(self, name, value):
310
  return setattr(self.pipe, name, value)
311
 
312
- @spaces.GPU
313
  def __call__(self, *args, **kwargs):
314
  return self.pipe(*args, **kwargs)
315
 
 
309
  def __setattr__(self, name, value):
310
  return setattr(self.pipe, name, value)
311
 
 
312
  def __call__(self, *args, **kwargs):
313
  return self.pipe(*args, **kwargs)
314
 
app.py CHANGED
@@ -20,7 +20,7 @@ code_to_block = {
20
  }
21
  lock = threading.Lock()
22
 
23
- @spaces.GPU
24
  def process_cache(cache, saes_dict):
25
 
26
  top_features_dict = {}
@@ -70,6 +70,9 @@ def create_prompt_part(pipe, saes_dict, demo):
70
  @spaces.GPU
71
  def image_gen(prompt):
72
  lock.acquire()
 
 
 
73
  try:
74
  images, cache = pipe.run_with_cache(
75
  prompt,
@@ -376,6 +379,7 @@ if __name__ == "__main__":
376
  variant=("fp16" if dtype==torch.float16 else None)
377
  )
378
  pipe.set_progress_bar_config(disable=True)
 
379
 
380
  path_to_checkpoints = './checkpoints/'
381
 
 
20
  }
21
  lock = threading.Lock()
22
 
23
+
24
  def process_cache(cache, saes_dict):
25
 
26
  top_features_dict = {}
 
70
  @spaces.GPU
71
  def image_gen(prompt):
72
  lock.acquire()
73
+ pipe.to('cuda')
74
+ for sae in saes_dict.values:
75
+ sae.to('cuda')
76
  try:
77
  images, cache = pipe.run_with_cache(
78
  prompt,
 
379
  variant=("fp16" if dtype==torch.float16 else None)
380
  )
381
  pipe.set_progress_bar_config(disable=True)
382
+ pipe.to('cuda')
383
 
384
  path_to_checkpoints = './checkpoints/'
385
 
utils/hooks.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  import spaces
3
 
4
- @spaces.GPU
5
  @torch.no_grad()
6
  def add_feature(sae, feature_idx, value, module, input, output):
7
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -12,7 +11,6 @@ def add_feature(sae, feature_idx, value, module, input, output):
12
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
13
 
14
 
15
- @spaces.GPU
16
  @torch.no_grad()
17
  def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
18
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -25,7 +23,6 @@ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output)
25
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
26
 
27
 
28
- @spaces.GPU
29
  @torch.no_grad()
30
  def replace_with_feature(sae, feature_idx, value, module, input, output):
31
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -36,7 +33,6 @@ def replace_with_feature(sae, feature_idx, value, module, input, output):
36
  return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
37
 
38
 
39
- @spaces.GPU
40
  @torch.no_grad()
41
  def reconstruct_sae_hook(sae, module, input, output):
42
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -45,7 +41,6 @@ def reconstruct_sae_hook(sae, module, input, output):
45
  return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
46
 
47
 
48
- @spaces.GPU
49
  @torch.no_grad()
50
  def ablate_block(module, input, output):
51
  return input
 
1
  import torch
2
  import spaces
3
 
 
4
  @torch.no_grad()
5
  def add_feature(sae, feature_idx, value, module, input, output):
6
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
11
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
12
 
13
 
 
14
  @torch.no_grad()
15
  def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
16
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
23
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
24
 
25
 
 
26
  @torch.no_grad()
27
  def replace_with_feature(sae, feature_idx, value, module, input, output):
28
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
33
  return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
34
 
35
 
 
36
  @torch.no_grad()
37
  def reconstruct_sae_hook(sae, module, input, output):
38
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
41
  return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
42
 
43
 
 
44
  @torch.no_grad()
45
  def ablate_block(module, input, output):
46
  return input