---

library_name: transformers
license: apache-2.0
language:
- ja
- en
---


## Model Merge

Gakki-7B was built with [Chat Vector](https://arxiv.org/abs/2310.04799): the parameter-wise difference between an instruction-tuned model and its base model is added to a target model, transferring the instruction-tuned behavior without additional training.

The recipe is as follows:

```
Rakuten/RakutenAI-7B-instruct + (prometheus-eval/prometheus-7b-v2.0 - mistralai/Mistral-7B-Instruct-v0.2)
```
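
Concretely, for every weight tensor outside the skipped layers, the merge performs the same element-wise arithmetic. A minimal sketch with toy stand-in tensors (not real model weights):

```python
import torch

# Toy stand-ins for one weight tensor in each model
w_base = torch.tensor([1.0, 2.0])    # mistralai/Mistral-7B-Instruct-v0.2
w_inst = torch.tensor([1.5, 1.0])    # prometheus-eval/prometheus-7b-v2.0
w_target = torch.tensor([0.5, 3.0])  # Rakuten/RakutenAI-7B-instruct

chat_vector = w_inst - w_base        # what instruction tuning changed
w_gakki = w_target + chat_vector     # transfer that change to the target
print(w_gakki)                       # tensor([1., 2.])
```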

## Source Code

```python
import torch
from transformers import AutoModelForCausalLM


def build_chat_vector_model(
    base_model_name,
    inst_model_name,
    target_model_name,
    skip_layers,
):

    # Base and instruction-tuned models stay on CPU; only the target,
    # which is modified in place, is loaded on the GPU.
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )
    inst_model = AutoModelForCausalLM.from_pretrained(
        inst_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cpu",
    )

    target_model = AutoModelForCausalLM.from_pretrained(
        target_model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    )

    # Inspect parameter names/shapes of the English base model
    for k, v in base_model.state_dict().items():
        print(k, v.shape)

    # Inspect parameter names/shapes of the Japanese continually pre-trained model
    for k, v in target_model.state_dict().items():
        print(k, v.shape)

    # Add the chat vector (instruct - base) to each target weight, skipping
    # the embedding/unembedding layers (the tokenizers differ) and layernorms.
    with torch.no_grad():
        for k, v in target_model.state_dict().items():
            if (k in skip_layers) or ("layernorm" in k):
                continue
            chat_vector = inst_model.state_dict()[k] - base_model.state_dict()[k]
            new_v = v + chat_vector.to(v.device)
            v.copy_(new_v)

    # NOTE: only the model weights are saved here; the tokenizer should be
    # copied from the target model's repository.
    target_model.save_pretrained("./Gakki-7B")


if __name__ == '__main__':

    base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    inst_model_name = "prometheus-eval/prometheus-7b-v2.0"
    target_model_name = "Rakuten/RakutenAI-7B-instruct"

    skip_layers = ["model.embed_tokens.weight", "lm_head.weight"]

    build_chat_vector_model(
        base_model_name=base_model_name,
        inst_model_name=inst_model_name,
        target_model_name=target_model_name,
        skip_layers=skip_layers
    )

```
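
## Usage

As a quick check, the merged checkpoint can be loaded back like any other `transformers` causal LM. A minimal sketch, assuming the tokenizer is taken from `Rakuten/RakutenAI-7B-instruct` (the script above saves only the model weights) and a simple USER/ASSISTANT prompt; consult the RakutenAI-7B-instruct card for the exact prompt format:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: the merge script does not save a tokenizer, so load it
# from the target model's repository instead.
tokenizer = AutoTokenizer.from_pretrained("Rakuten/RakutenAI-7B-instruct")
model = AutoModelForCausalLM.from_pretrained(
    "./Gakki-7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Assumption: a simple USER/ASSISTANT prompt; check the
# Rakuten/RakutenAI-7B-instruct card for the exact template.
prompt = "USER: 日本の首都はどこですか？ ASSISTANT:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128, do_sample=False)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```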