---
license: apache-2.0
---
# Dataset

This model was trained on the Japanese subset of the [mC4](https://huggingface.co/datasets/mc4) dataset.
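
For reference, the Japanese split of mC4 can be streamed with the `datasets` library rather than downloaded in full (a minimal sketch; the dataset id and language config follow the link above, and the split is very large):

```python
from datasets import load_dataset

# Stream the Japanese split of mC4 instead of downloading it in full
dataset = load_dataset("mc4", "ja", split="train", streaming=True)

# Peek at the first document
print(next(iter(dataset))["text"][:200])
```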

# Training

Trained for 3,000 steps on top of the MPT-7B checkpoint [mosaicml/mpt-7b](https://huggingface.co/mosaicml/mpt-7b).

# How to load

Before loading this model, please install the following pip package, which MPT's custom modeling code depends on:

```bash
pip install einops
```

To load the model, run the following code:

```python
from transformers import AutoModelForCausalLM

model_name = "lightblue/japanese-mpt-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    trust_remote_code=True
)
```
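
By default this loads the model onto CPU. If you have a GPU with enough memory, you can move the model over before generating (a sketch):

```python
import torch

# Move the model to GPU if one is available
if torch.cuda.is_available():
    model = model.to("cuda")
```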

Depending on your GPU, you may need to load the model in lower precision for it to fit in memory. We found that on a T4 GPU, the model must be loaded in 8-bit precision. To load the model in 8-bit, first install the following pip packages:

```bash
pip install bitsandbytes accelerate
```

Caution - you will also need enough system RAM to load the model; we estimate this requires ~30GB.
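
Once loaded, you can check the model's actual footprint with the built-in `transformers` helper (a sketch):

```python
# Report the model's in-memory size in GB
print(f"{model.get_memory_footprint() / 1e9:.1f} GB")
```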

<details>
<summary><b>In 8-bit</b></summary>



```python
from transformers import AutoModelForCausalLM

model_name = "lightblue/japanese-mpt-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype='auto',
    load_in_8bit=True,
    device_map="auto",  # required by the bitsandbytes 8-bit integration; places weights on the GPU
    trust_remote_code=True
)
```

</details>


# How to use
```python
from transformers import AutoTokenizer, pipeline

# Reuses `model_name` and `model` from the loading snippet above
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Greedy (deterministic) decoding; with do_sample=False, temperature is ignored
pipe("こんにちは", do_sample=False, return_full_text=False, max_new_tokens=32)
```
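
For more varied output, you could instead enable sampling; the temperature and top_p values below are illustrative, not tuned for this model:

```python
# Sampled (non-deterministic) generation
output = pipe(
    "こんにちは",  # "Hello"
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
    return_full_text=False,
    max_new_tokens=32,
)
print(output[0]["generated_text"])
```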