---
license: cc-by-nc-4.0
---
## License
This model is released under a non-commercial license.
## Chat Vector
```
Tora-7B-v0.2 = NTQAI/chatntq-ja-7b-v1.0 + (NousResearch/Hermes-2-Pro-Mistral-7B - mistralai/Mistral-7B-v0.1)
```
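The formula above is applied element-wise to each parameter tensor. As a minimal sketch of that arithmetic, the toy example below uses plain Python lists in place of real weights; the function name `apply_chat_vector` and all values are made up for illustration and are not part of the original implementation.

```python
def apply_chat_vector(target, inst, base):
    """Return target + (inst - base), computed element-wise per parameter."""
    merged = {}
    for name, target_values in target.items():
        merged[name] = [
            t + (i - b)
            for t, i, b in zip(target_values, inst[name], base[name])
        ]
    return merged


# Toy "state dicts" with made-up values (not real model weights).
base = {"w": [1.0, 2.0]}    # stands in for mistralai/Mistral-7B-v0.1
inst = {"w": [1.5, 1.0]}    # stands in for NousResearch/Hermes-2-Pro-Mistral-7B
target = {"w": [4.0, 8.0]}  # stands in for NTQAI/chatntq-ja-7b-v1.0

merged = apply_chat_vector(target, inst, base)
print(merged)  # {'w': [4.5, 7.0]}
```

The difference `inst - base` is the chat vector; adding it to the target weights is all the merge does for each parameter it touches.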
## Implementation
The model was built with the code below, which is based on @jovyan's implementation.
```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):
    # English base model
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Instruction-tuned model derived from the base model
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    # Japanese continued-pretraining model that receives the chat vector
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()

    # Print parameter names and shapes of the English base model
    for k, v in base_state.items():
        print(k, v.shape)

    # Print parameter names and shapes of the Japanese model
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    for k, v in target_model.state_dict().items():
        # Skip the excluded parameters (skip_layers) and any layernorm weights
        if (k in skip_layers) or ("layernorm" in k):
            continue
        # Chat vector = instruction-tuned weights - base weights
        chat_vector = inst_state[k] - base_state[k]
        new_v = v + chat_vector.to(v.device)
        # Write the result back into the target model in place
        v.copy_(new_v)

    target_model.save_pretrained("./chat_model")


if __name__ == "__main__":
    base_model_name = "mistralai/Mistral-7B-v0.1"
    inst_model_name = "NousResearch/Hermes-2-Pro-Mistral-7B"
    target_model_name = "NTQAI/chatntq-ja-7b-v1.0"

    # Parameters excluded from the merge
    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers,
    )
```
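To make the exclusion rule in the merge loop easier to follow, here is a small sketch that applies the same condition to a handful of parameter names. The helper `should_skip` is hypothetical, and the names assume the usual Hugging Face Mistral naming scheme; no model is loaded.

```python
def should_skip(name, skip_layers):
    """Same condition as the merge loop: listed names or any name containing 'layernorm'."""
    return (name in skip_layers) or ("layernorm" in name)


skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

# Example parameter names (assumed Hugging Face Mistral naming).
names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "model.norm.weight",
    "lm_head.weight",
]

for name in names:
    action = "skip" if should_skip(name, skip_layers) else "add chat vector"
    print(f"{name}: {action}")
```

Under this naming, the final `model.norm.weight` does not contain the substring `layernorm`, so the condition as written would not skip it.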
## Benchmark (Japanese MT bench)
|Model|Category|Score|Setting|
|:---|:---|:---|:---|
|Tora-7B-v0.2|Writing|3.8|single-turn|
|Tora-7B-v0.2|Roleplay|7.1|single-turn|
|Tora-7B-v0.2|Reasoning|6.3|single-turn|
|Tora-7B-v0.2|Math|3.0|single-turn|
|Tora-7B-v0.2|Coding|2.2|single-turn|
|Tora-7B-v0.2|Extraction|6.6|single-turn|
|Tora-7B-v0.2|STEM|7.2|single-turn|
|Tora-7B-v0.2|Humanities|8.2|single-turn|
![image/png](https://cdn-uploads.huggingface.co/production/uploads/651e3f30ca333f3c8df692b8/_CBS90NRrYUMXzsFC1LIV.png)
## Acknowledgments
My sincere thanks to @jovyan for writing the Chat Vector article.
## References
[Chat Vectorを使って日本語LLMをチャットモデルに改造する (Converting a Japanese LLM into a chat model with Chat Vector)](https://qiita.com/jovyan/items/ee6affa5ee5bdaada6b4)