imone committed
Commit
0a61848
1 Parent(s): 025fef8

Update README.md

Files changed (1)
  1. README.md +50 -1
README.md CHANGED
@@ -12,4 +12,53 @@ The original Llama 3 8b (base) special token weights are zero, which might cause
 <|end_header_id|>
 ```
 
-We set the weights of these tokens in `embed` and `lm_head` to be the mean of all other tokens.
+We set the weights of these tokens in `embed` and `lm_head` to be the mean of all other tokens.
+
+Code for making this model:
+
+```python
+import argparse
+
+import transformers
+import torch
+
+
+def init_eot_embedding_llama3(model_path, output_dir, special_tokens=["<|eot_id|>", "<|start_header_id|>", "<|end_header_id|>"], mean_cutoff=128000, dtype=torch.bfloat16):
+    tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
+    model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=dtype)
+
+    assert model.model.embed_tokens.weight.shape[0] >= mean_cutoff
+    assert model.lm_head.weight.shape[0] >= mean_cutoff
+
+    with torch.no_grad():
+        for token in special_tokens:
+            token_id = tokenizer.convert_tokens_to_ids(token)
+
+            print(f"Token {token} ID {token_id}")
+
+            model.model.embed_tokens.weight[token_id] = torch.mean(model.model.embed_tokens.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
+            model.lm_head.weight[token_id] = torch.mean(model.lm_head.weight[:mean_cutoff].to(torch.float32), dim=0).to(dtype)
+
+    # Save
+    tokenizer.save_pretrained(output_dir)
+    model.save_pretrained(output_dir)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--model-path",
+        help="Location of model, or HuggingFace repo ID",
+    )
+    parser.add_argument(
+        "--output-dir",
+        help="Location to write resulting model and tokenizer",
+    )
+
+    init_eot_embedding_llama3(**vars(parser.parse_args()))
+
+
+if __name__ == "__main__":
+    main()
+
+```
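
A minimal sketch of how the script above could be invoked, assuming it is saved as `init_eot_embedding_llama3.py` (the filename and output directory are illustrative; `meta-llama/Meta-Llama-3-8B` is the Llama 3 8b base repo the fix targets):

```bash
python init_eot_embedding_llama3.py \
    --model-path meta-llama/Meta-Llama-3-8B \
    --output-dir ./llama-3-8b-special-tokens-fixed
```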