Spaces: Running on Zero

Commit: fixing

Files changed:
- SAE/sae.py +0 -1
- SDLens/hooked_sd_pipeline.py +0 -1
- app.py +5 -1
- utils/hooks.py +0 -5
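At a glance, the commit strips `@spaces.GPU` from the library-level helpers (`SAE/sae.py`, `SDLens/hooked_sd_pipeline.py`, `utils/hooks.py`) and keeps a single decorated entry point in `app.py`, with explicit `.to('cuda')` moves inside it. For context, this is roughly the ZeroGPU pattern: a GPU is attached only for the duration of a `@spaces.GPU`-decorated call, so device moves belong inside that call. A minimal sketch, assuming a ZeroGPU Space (the `model`/`run` names are illustrative, not from this repo):

```python
import spaces          # Hugging Face ZeroGPU helper, available on Spaces
import torch
import torch.nn as nn

# Illustrative stand-in for the Space's real models; loaded on CPU at startup.
model = nn.Linear(4, 4)

@spaces.GPU  # ZeroGPU attaches a GPU only while this function runs
def run(x: torch.Tensor) -> torch.Tensor:
    # CUDA is only guaranteed inside the decorated call, so weights are
    # moved here, mirroring the commit's pipe.to('cuda') / sae.to('cuda').
    model.to('cuda')
    return model(x.to('cuda')).cpu()
```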
SAE/sae.py CHANGED

```diff
@@ -105,7 +105,6 @@ class SparseAutoencoder(nn.Module):
     def n_dirs(self):
         return self.n_dirs_local
 
-    @spaces.GPU
     def encode(self, x):
         x = x.to('cuda') - self.pre_bias
         latents_pre_act = self.encoder(x) + self.latent_bias
```
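With the decorator gone, `encode` still moves its input with `x.to('cuda')`, so it now relies on the caller to have a GPU attached and the SAE's own parameters already placed on it (which `app.py` below arranges via `sae.to('cuda')`). A minimal stand-in mirroring that path, with illustrative dimensions not taken from the repo:

```python
import torch
import torch.nn as nn

class TinySAE(nn.Module):
    # Illustrative sizes; the real SparseAutoencoder's fields beyond
    # pre_bias / latent_bias / encoder are not shown in this diff.
    def __init__(self, d_model: int = 8, n_latents: int = 32):
        super().__init__()
        self.pre_bias = nn.Parameter(torch.zeros(d_model))
        self.latent_bias = nn.Parameter(torch.zeros(n_latents))
        self.encoder = nn.Linear(d_model, n_latents, bias=False)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # As in the hunk above: the input is moved to CUDA, so the module
        # itself must already live there (self.pre_bias is on CUDA only
        # after sae.to('cuda')).
        x = x.to('cuda') - self.pre_bias
        return self.encoder(x) + self.latent_bias
```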
SDLens/hooked_sd_pipeline.py CHANGED

```diff
@@ -309,7 +309,6 @@ class HookedDiffusionAbstractPipeline:
     def __setattr__(self, name, value):
         return setattr(self.pipe, name, value)
 
-    @spaces.GPU
     def __call__(self, *args, **kwargs):
         return self.pipe(*args, **kwargs)
 
```
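`HookedDiffusionAbstractPipeline` is a thin wrapper: attribute writes and calls are forwarded to the wrapped `self.pipe`, so removing the decorator leaves `__call__` as a plain pass-through. A sketch of that delegation pattern (only `__setattr__` and `__call__` appear in the hunk; the rest is assumed):

```python
class PipeWrapper:
    def __init__(self, pipe):
        # With __setattr__ forwarding everything, the wrapped object must
        # be stored via object.__setattr__ to avoid infinite recursion.
        object.__setattr__(self, 'pipe', pipe)

    def __setattr__(self, name, value):
        return setattr(self.pipe, name, value)

    def __getattr__(self, name):
        # Assumed counterpart for attribute reads; not shown in the diff.
        return getattr(self.pipe, name)

    def __call__(self, *args, **kwargs):
        return self.pipe(*args, **kwargs)
```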
app.py CHANGED

```diff
@@ -20,7 +20,7 @@ code_to_block = {
 }
 lock = threading.Lock()
 
-
+
 def process_cache(cache, saes_dict):
 
     top_features_dict = {}
@@ -70,6 +70,9 @@ def create_prompt_part(pipe, saes_dict, demo):
     @spaces.GPU
     def image_gen(prompt):
         lock.acquire()
+        pipe.to('cuda')
+        for sae in saes_dict.values():
+            sae.to('cuda')
         try:
             images, cache = pipe.run_with_cache(
                 prompt,
@@ -376,6 +379,7 @@ if __name__ == "__main__":
         variant=("fp16" if dtype==torch.float16 else None)
     )
     pipe.set_progress_bar_config(disable=True)
+    pipe.to('cuda')
 
     path_to_checkpoints = './checkpoints/'
 
```
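One detail worth flagging: the new `.to('cuda')` moves land between `lock.acquire()` and the `try:` block, so an exception while moving weights would leave the lock held forever. Assuming the function releases the lock in a `finally:` further down (the diff truncates before it), a safer ordering would be:

```python
@spaces.GPU
def image_gen(prompt):
    lock.acquire()
    try:
        # Inside try: if .to('cuda') raises, the lock is still released.
        pipe.to('cuda')
        for sae in saes_dict.values():
            sae.to('cuda')
        images, cache = pipe.run_with_cache(
            prompt,
            # ... remaining arguments elided in the diff ...
        )
    finally:
        lock.release()
```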
utils/hooks.py CHANGED

```diff
@@ -1,7 +1,6 @@
 import torch
 import spaces
 
-@spaces.GPU
 @torch.no_grad()
 def add_feature(sae, feature_idx, value, module, input, output):
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -12,7 +11,6 @@ def add_feature(sae, feature_idx, value, module, input, output):
     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
 
 
-@spaces.GPU
 @torch.no_grad()
 def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -25,7 +23,6 @@ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output)
     return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
 
 
-@spaces.GPU
 @torch.no_grad()
 def replace_with_feature(sae, feature_idx, value, module, input, output):
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -36,7 +33,6 @@ def replace_with_feature(sae, feature_idx, value, module, input, output):
     return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
 
 
-@spaces.GPU
 @torch.no_grad()
 def reconstruct_sae_hook(sae, module, input, output):
     diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -45,7 +41,6 @@ def reconstruct_sae_hook(sae, module, input, output):
     return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
 
 
-@spaces.GPU
 @torch.no_grad()
 def ablate_block(module, input, output):
     return input
```
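These helpers take extra leading arguments (`sae`, `feature_idx`, ...) before the standard `(module, input, output)` triple, which suggests they are bound with `functools.partial` before being registered as PyTorch forward hooks. A sketch of that usage with placeholder objects (the repo's actual registration code is not part of this diff):

```python
import functools
import torch.nn as nn

from utils.hooks import add_feature

block = nn.Conv2d(4, 4, kernel_size=1)  # placeholder for a U-Net block
sae = ...                               # placeholder: a trained SparseAutoencoder
feature_idx, value = 0, 5.0             # illustrative values

# functools.partial binds the leading custom arguments so the callable
# matches PyTorch's hook signature fn(module, input, output). A hook's
# non-None return value replaces the module's output.
handle = block.register_forward_hook(
    functools.partial(add_feature, sae, feature_idx, value)
)
# ... run the pipeline; the hook edits the block's output in flight ...
handle.remove()  # detach the hook when finished
```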