---
library_name: transformers
license: apache-2.0
language:
- ja
- en
---
## Model Merge

Gakki-7B was built using the [Chat Vector](https://arxiv.org/abs/2310.04799) method.
|
|
|
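The merge recipe is shown below. The chat vector (the weight difference between prometheus-eval/prometheus-7b-v2.0 and mistralai/Mistral-7B-Instruct-v0.2) is added to Rakuten/RakutenAI-7B-instruct. The embedding, LM-head and layer-norm weights of the target model are left unchanged.

```text
Rakuten/RakutenAI-7B-instruct + (prometheus-eval/prometheus-7b-v2.0 - mistralai/Mistral-7B-Instruct-v0.2)
```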
|
|
|
## Source Code

```python
|
import torch |
|
from transformers import AutoModelForCausalLM |
|
|
|
|
|
def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers=None,
    *,
    output_dir="./Gakki-7B",
    target_device="cuda",
    skip_layernorm=True,
    verbose=True,
):
    """Add a chat vector to a target model and save the merged weights.

    The chat vector is the per-parameter difference ``inst - base``. It is
    added in place to every parameter of the target model, i.e.
    ``target + (inst - base)``. Two groups of parameters are left unchanged:

    * parameters listed in ``skip_layers``;
    * when ``skip_layernorm`` is true, parameters whose name contains
      ``"layernorm"``.

    Args:
        base_model_name: Model id or path of the model subtracted from the
            instruction model. Loaded in bfloat16 on CPU.
        inst_model_name: Model id or path of the instruction/chat model.
            Loaded in bfloat16 on CPU.
        target_model_name: Model id or path of the model that receives the
            chat vector. Loaded in bfloat16 with ``device_map=target_device``.
        skip_layers: Exact parameter names to leave unchanged. ``None`` (the
            default) means ``["model.embed_tokens.weight", "lm_head.weight"]``.
        output_dir: Directory the merged model is written to via
            ``save_pretrained``.
        target_device: ``device_map`` used when loading the target model.
        skip_layernorm: If true, also leave every parameter whose name
            contains ``"layernorm"`` unchanged.
        verbose: If true, print the name and shape of every base-model and
            target-model parameter before merging.

    Returns:
        The merged target model (the same weights that were saved to
        ``output_dir``).

    Raises:
        KeyError: If a target parameter that is not skipped is missing from
            the base or instruction model's state dict.
    """
    if skip_layers is None:
        skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]
    # Honour the caller's list (it used to be overwritten by a hard-coded one);
    # a set gives O(1) membership tests in the merge loop.
    skipped_names = set(skip_layers)

    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map=target_device,
    )

    # state_dict() collects every parameter on each call, so build each one
    # once instead of twice per key inside the merge loop.
    base_state = base_model.state_dict()
    inst_state = inst_model.state_dict()
    target_state = target_model.state_dict()

    if verbose:
        # Base (English) model
        for name, tensor in base_state.items():
            print(name, tensor.shape)
        # Target (Japanese continually pre-trained) model
        for name, tensor in target_state.items():
            print(name, tensor.shape)

    # The state-dict tensors share storage with the target model's parameters,
    # so in-place updates modify the model that is saved below; no autograd
    # tracking is needed for this weight arithmetic.
    with torch.no_grad():
        for name, param in target_state.items():
            # NOTE(review): a normalisation weight whose name lacks
            # "layernorm" is not caught by this check and does receive the
            # chat vector — confirm that is intended.
            if name in skipped_names or (skip_layernorm and "layernorm" in name):
                continue
            chat_vector = inst_state[name] - base_state[name]
            param.add_(chat_vector.to(param.device))

    target_model.save_pretrained(output_dir)
    return target_model
|
|
|
|
|
if __name__ == "__main__":
    # Merge recipe: target + (inst - base).
    # The embedding matrix and LM head are listed as skipped, so the target
    # model keeps its own.
    merge_config = dict(
        base_model_name="mistralai/Mistral-7B-Instruct-v0.2",
        inst_model_name="prometheus-eval/prometheus-7b-v2.0",
        target_model_name="Rakuten/RakutenAI-7B-instruct",
        skip_layers=[
            "model.embed_tokens.weight",
            "lm_head.weight",
        ],
    )
    build_chat_vector_model(**merge_config)
|
|
|
```