sfeucht committed
Commit f880b38
1 Parent(s): f0f61ae

Update README.md

Files changed (1): README.md +6 -1
README.md CHANGED
@@ -5,7 +5,7 @@ Linear probe checkpoints for https://footprints.baulab.info
 
 To load a Llama-2-7b checkpoint at layer 0 and target index -3:
 
-```
+```python
 import torch
 import torch.nn as nn
 from huggingface_hub import hf_hub_download
@@ -18,10 +18,15 @@ class LinearModel(nn.Module):
         output = self.fc(x)
         return output
 
+# example: llama-2-7b probe at layer 0, predicting 3 tokens ago
+# predicting the next token would be `layer0_tgtidx1.ckpt`
 checkpoint_path = hf_hub_download(
     repo_id="sfeucht/footprints",
     filename="llama-2-7b/layer0_tgtidx-3.ckpt"
 )
+
+# model_size is 4096 for both models.
+# vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b
 probe = LinearModel(4096, 32000)
 probe.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
 ```
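
The README snippet in this diff elides the constructor of `LinearModel`. A minimal sketch consistent with the `fc` layer and the `(model_size, vocab_size)` arguments shown above, followed by a hypothetical example of applying the loaded probe to a hidden state, might look like the code below; the `__init__` body and the `hidden_state` usage are assumptions for illustration, not part of this commit.

```python
# Sketch under assumptions: __init__ is inferred from the `fc` key and the
# (model_size, vocab_size) arguments in the README; the hidden-state usage
# at the bottom is hypothetical and not part of this commit.
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

class LinearModel(nn.Module):
    def __init__(self, model_size, vocab_size):
        super().__init__()
        self.fc = nn.Linear(model_size, vocab_size)

    def forward(self, x):
        output = self.fc(x)
        return output

checkpoint_path = hf_hub_download(
    repo_id="sfeucht/footprints",
    filename="llama-2-7b/layer0_tgtidx-3.ckpt"
)
probe = LinearModel(4096, 32000)
probe.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))

# Given a layer-0 hidden state (4096-dim for Llama-2-7b), the probe outputs
# logits over the 32000-token vocabulary; the argmax is its prediction for
# the token 3 positions back.
hidden_state = torch.randn(4096)   # stand-in for a real hidden state
logits = probe(hidden_state)       # shape: [32000]
predicted_token_id = logits.argmax().item()
```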