Update README.md
Browse files
README.md
CHANGED
```diff
@@ -5,7 +5,7 @@ Linear probe checkpoints for https://footprints.baulab.info
 
 To load a Llama-2-7b checkpoint at layer 0 and target index -3:
 
-```
+```python
 import torch
 import torch.nn as nn
 from huggingface_hub import hf_hub_download
@@ -18,10 +18,15 @@ class LinearModel(nn.Module):
         output = self.fc(x)
         return output
 
+# example: llama-2-7b probe at layer 0, predicting 3 tokens ago
+# predicting the next token would be `layer0_tgtidx1.ckpt`
 checkpoint_path = hf_hub_download(
     repo_id="sfeucht/footprints",
     filename="llama-2-7b/layer0_tgtidx-3.ckpt"
 )
+
+# model_size is 4096 for both models.
+# vocab_size is 32000 for Llama-2-7b and 128256 for Llama-3-8b
 probe = LinearModel(4096, 32000)
 probe.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
 ```
```