surokpro2 commited on
Commit
79ab360
1 Parent(s): c8eef65
Files changed (1) hide show
  1. SAE/sae.py +1 -4
SAE/sae.py CHANGED
@@ -107,10 +107,7 @@ class SparseAutoencoder(nn.Module):
107
 
108
  @spaces.GPU
109
  def encode(self, x):
110
-
111
- print(x.device)
112
- print(self.pre_bias.device)
113
- x = x - self.pre_bias
114
  latents_pre_act = self.encoder(x) + self.latent_bias
115
 
116
  vals, inds = torch.topk(
 
107
 
108
  @spaces.GPU
109
  def encode(self, x):
110
+ x = x.to('cuda') - self.pre_bias
 
 
 
111
  latents_pre_act = self.encoder(x) + self.latent_bias
112
 
113
  vals, inds = torch.topk(