surokpro2 commited on
Commit
c8a2456
1 Parent(s): 46611c9

Update utils/hooks.py

Browse files
Files changed (1) hide show
  1. utils/hooks.py +6 -0
utils/hooks.py CHANGED
@@ -1,5 +1,7 @@
1
  import torch
 
2
 
 
3
  @torch.no_grad()
4
  def add_feature(sae, feature_idx, value, module, input, output):
5
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -10,6 +12,7 @@ def add_feature(sae, feature_idx, value, module, input, output):
10
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
11
 
12
 
 
13
  @torch.no_grad()
14
  def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
15
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -22,6 +25,7 @@ def add_feature_on_area(sae, feature_idx, activation_map, module, input, output)
22
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
23
 
24
 
 
25
  @torch.no_grad()
26
  def replace_with_feature(sae, feature_idx, value, module, input, output):
27
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -32,6 +36,7 @@ def replace_with_feature(sae, feature_idx, value, module, input, output):
32
  return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
33
 
34
 
 
35
  @torch.no_grad()
36
  def reconstruct_sae_hook(sae, module, input, output):
37
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
@@ -40,6 +45,7 @@ def reconstruct_sae_hook(sae, module, input, output):
40
  return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
41
 
42
 
 
43
  @torch.no_grad()
44
  def ablate_block(module, input, output):
45
  return input
 
1
  import torch
2
+ import spaces
3
 
4
+ @spaces.GPU
5
  @torch.no_grad()
6
  def add_feature(sae, feature_idx, value, module, input, output):
7
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
12
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
13
 
14
 
15
+ @spaces.GPU
16
  @torch.no_grad()
17
  def add_feature_on_area(sae, feature_idx, activation_map, module, input, output):
18
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
25
  return (output[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
26
 
27
 
28
+ @spaces.GPU
29
  @torch.no_grad()
30
  def replace_with_feature(sae, feature_idx, value, module, input, output):
31
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
36
  return (input[0] + to_add.permute(0, 3, 1, 2).to(output[0].device),)
37
 
38
 
39
+ @spaces.GPU
40
  @torch.no_grad()
41
  def reconstruct_sae_hook(sae, module, input, output):
42
  diff = (output[0] - input[0]).permute((0, 2, 3, 1)).to(sae.device)
 
45
  return (input[0] + reconstructed.permute(0, 3, 1, 2).to(output[0].device),)
46
 
47
 
48
+ @spaces.GPU
49
  @torch.no_grad()
50
  def ablate_block(module, input, output):
51
  return input