ryota39 committed on
Commit
a9aa1bf
1 Parent(s): a47b8cd

Update README.md

Files changed (1): README.md (+56 -3)
README.md CHANGED
@@ -1,3 +1,56 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ datasets:
+ - llm-jp/databricks-dolly-15k-ja
+ language:
+ - ja
+ library_name: transformers
+ ---
+ ## Model
+
+ - Base model: [llm-jp/llm-jp-1.3b-v1.0](https://huggingface.co/llm-jp/llm-jp-1.3b-v1.0)
+ - Training dataset: [llm-jp/databricks-dolly-15k-ja](https://huggingface.co/datasets/llm-jp/databricks-dolly-15k-ja)
+ - Training method: full-parameter fine-tuning
+
+ ## Sample
+
+ ```python
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ tokenizer = AutoTokenizer.from_pretrained(
+     "ryota39/llm-jp-1b-sft-15k"
+ )
+ pad_token_id = tokenizer.pad_token_id
+
+ model = AutoModelForCausalLM.from_pretrained(
+     "ryota39/llm-jp-1b-sft-15k",
+     device_map="auto",
+     torch_dtype=torch.float16,
+ )
+
+ text = "東京の観光名所を教えてください。\n"  # "Tell me about sightseeing spots in Tokyo."
+ tokenized_input = tokenizer.encode(
+     text,
+     add_special_tokens=False,
+     return_tensors="pt"
+ ).to(model.device)
+
+ # Mask out any padding tokens (a no-op for this single unpadded prompt)
+ attention_mask = torch.ones_like(tokenized_input)
+ attention_mask[tokenized_input == pad_token_id] = 0
+
+ with torch.no_grad():
+     output = model.generate(
+         tokenized_input,
+         attention_mask=attention_mask,
+         max_new_tokens=128,
+         do_sample=False,  # greedy decoding; uncomment below to sample instead
+         # top_p=0.8,
+         # temperature=0.8,
+         repetition_penalty=1.0,
+     )[0]
+
+ print(tokenizer.decode(output))
+ ```
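"Full-parameter tuning" in the README above means every weight of the base model is updated during SFT, as opposed to adapter methods like LoRA that freeze the backbone. As a minimal toy sketch of that idea (the tiny stand-in model, random data, and hyperparameters here are illustrative assumptions, not the actual llm-jp-1.3b training setup):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy causal-LM stand-in: embedding + linear head (placeholder for the 1.3B model)
vocab_size, hidden = 100, 16
model = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, vocab_size))

# Dummy batch: next-token prediction, labels are the inputs shifted by one
input_ids = torch.randint(0, vocab_size, (2, 8))
labels = input_ids[:, 1:]

# Full-parameter tuning: every parameter is trainable and handed to the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

logits = model(input_ids)[:, :-1]  # drop the last position (no label for it)
loss = nn.functional.cross_entropy(
    logits.reshape(-1, vocab_size), labels.reshape(-1)
)
loss.backward()
optimizer.step()

# Every parameter received a gradient, i.e. nothing was frozen
assert all(p.grad is not None for p in model.parameters())
```

With LoRA-style tuning, by contrast, most of these parameters would have `requires_grad=False` and receive no gradient.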