Spaces:
Running
on
Zero
Running
on
Zero
Update utils/hooks.py
Browse files- utils/hooks.py +6 -0
utils/hooks.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1 |
import torch
|
|
|
2 |
|
|
|
3 |
@torch.no_grad()
|
4 |
def add_feature(sae, feature_idx, value, module, input, output):
|
5 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
@@ -10,6 +12,7 @@ def add_feature(sae, feature_idx, value, module, input, output):
|
|
10 |
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
11 |
|
12 |
|
|
|
13 |
@torch.no_grad()
|
14 |
def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
|
15 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
@@ -22,6 +25,7 @@ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output)
|
|
22 |
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
23 |
|
24 |
|
|
|
25 |
@torch.no_grad()
|
26 |
def replace_with_feature(sae, feature_idx, value, module, input, output):
|
27 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
@@ -32,6 +36,7 @@ def replace_with_feature(sae, feature_idx, value, module, input, output):
|
|
32 |
return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
33 |
|
34 |
|
|
|
35 |
@torch.no_grad()
|
36 |
def reconstruct_sae_hook(sae, module, input, output):
|
37 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
@@ -40,6 +45,7 @@ def reconstruct_sae_hook(sae, module, input, output):
|
|
40 |
return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
|
41 |
|
42 |
|
|
|
43 |
@torch.no_grad()
|
44 |
def ablate_block(module, input, output):
|
45 |
return input
|
|
|
1 |
import torch
|
2 |
+
import spaces
|
3 |
|
4 |
+
@spaces.GPU
|
5 |
@torch.no_grad()
|
6 |
def add_feature(sae, feature_idx, value, module, input, output):
|
7 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
|
|
12 |
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
13 |
|
14 |
|
15 |
+
@spaces.GPU
|
16 |
@torch.no_grad()
|
17 |
def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
|
18 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
|
|
25 |
return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
26 |
|
27 |
|
28 |
+
@spaces.GPU
|
29 |
@torch.no_grad()
|
30 |
def replace_with_feature(sae, feature_idx, value, module, input, output):
|
31 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
|
|
36 |
return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
|
37 |
|
38 |
|
39 |
+
@spaces.GPU
|
40 |
@torch.no_grad()
|
41 |
def reconstruct_sae_hook(sae, module, input, output):
|
42 |
diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
|
|
|
45 |
return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
|
46 |
|
47 |
|
48 |
+
@spaces.GPU
|
49 |
@torch.no_grad()
|
50 |
def ablate_block(module, input, output):
|
51 |
return input
|