|
--- |
|
language: ja |
|
thumbnail: https://github.com/rinnakk/japanese-gpt2/blob/master/rinna.png |
|
tags: |
|
- ja |
|
- japanese |
|
- roberta |
|
- masked-lm |
|
- nlp |
|
license: mit |
|
datasets: |
|
- cc100 |
|
- wikipedia |
|
--- |
|
|
|
# japanese-roberta-base |
|
|
|
![rinna-icon](./rinna.png) |
|
|
|
This repository provides a base-sized Japanese RoBERTa model. The model is provided by [rinna](https://corp.rinna.co.jp/). |
|
|
|
# How to use the model |
|
|
|
*NOTE:* Use `T5Tokenizer` to initiate the tokenizer. |
|
|
|
~~~~ |
|
from transformers import T5Tokenizer, RobertaForMaskedLM |
|
|
|
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base") |
|
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading |
|
|
|
model = RobertaForMaskedLM.from_pretrained("rinna/japanese-roberta-base") |
|
~~~~ |
|
|
|
# How to use the model for masked token prediction |
|
|
|
*NOTE:* To predict a masked token, be sure to add a `[CLS]` token before the sentence for the model to correctly encode it, as it is used during the model training. |
|
|
|
Here we adopt the example by [kenta1984](https://qiita.com/kenta1984/items/7f3a5d859a15b20657f3#%E6%97%A5%E6%9C%AC%E8%AA%9Epre-trained-models) to illustrate how our model works as a masked language model. |
|
|
|
~~~~ |
|
# original text |
|
text = "テレビでサッカーの試合を見る。" |
|
|
|
# prepend [CLS] |
|
text = "[CLS]" + text |
|
|
|
# tokenize |
|
tokens = tokenizer.tokenize(text) |
|
print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', 'サッカー', 'の試合', 'を見る', '。'] |
|
|
|
# mask a token |
|
masked_idx = 4 |
|
tokens[masked_idx] = tokenizer.mask_token |
|
print(tokens) # output: ['[CLS]', '▁', 'テレビ', 'で', '[MASK]', 'の試合', 'を見る', '。'] |
|
|
|
# convert to ids |
|
token_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
print(token_ids) # output: [4, 9, 480, 19, 6, 8466, 6518, 8] |
|
|
|
# convert to tensor |
|
import torch |
|
token_tensor = torch.tensor([token_ids]) |
|
|
|
# get the top 50 predictions of the masked token |
|
model = model.eval() |
|
with torch.no_grad(): |
|
outputs = model(token_tensor) |
|
predictions = outputs[0][0, masked_idx].topk(100) |
|
for i, index_t in enumerate(predictions.indices): |
|
index = index_t.item() |
|
token = tokenizer.convert_ids_to_tokens([index])[0] |
|
print(i, token) |
|
|
|
""" |
|
0 サッカー |
|
1 メジャーリーグ |
|
2 フィギュアスケート |
|
3 バレーボール |
|
4 ボクシング |
|
5 ワールドカップ |
|
6 バスケットボール |
|
7 阪神タイガース |
|
8 プロ野球 |
|
9 アメリカンフットボール |
|
10 日本代表 |
|
11 高校野球 |
|
12 福岡ソフトバンクホークス |
|
13 プレミアリーグ |
|
14 ファイターズ |
|
15 ラグビー |
|
16 東北楽天ゴールデンイーグルス |
|
17 中日ドラゴンズ |
|
18 アイスホッケー |
|
19 フットサル |
|
20 サッカー選手 |
|
21 スポーツ |
|
22 チャンピオンズリーグ |
|
23 ジャイアンツ |
|
24 ソフトボール |
|
25 バスケット |
|
26 フットボール |
|
27 新日本プロレス |
|
28 バドミントン |
|
29 千葉ロッテマリーンズ |
|
30 <unk> |
|
31 北京オリンピック |
|
32 広島東洋カープ |
|
33 キックボクシング |
|
34 オリンピック |
|
35 ロンドンオリンピック |
|
36 読売ジャイアンツ |
|
37 テニス |
|
38 東京オリンピック |
|
39 日本シリーズ |
|
40 ヤクルトスワローズ |
|
41 タイガース |
|
42 サッカークラブ |
|
43 ハンドボール |
|
44 野球 |
|
45 バルセロナ |
|
46 ホッケー |
|
47 格闘技 |
|
48 大相撲 |
|
49 ブンデスリーガ |
|
50 スキージャンプ |
|
51 プロサッカー選手 |
|
52 ヤンキース |
|
53 社会人野球 |
|
54 クライマックスシリーズ |
|
55 クリケット |
|
56 トップリーグ |
|
57 パラリンピック |
|
58 クラブチーム |
|
59 ニュージーランド |
|
60 総合格闘技 |
|
61 ウィンブルドン |
|
62 ドラゴンボール |
|
63 レスリング |
|
64 ドラゴンズ |
|
65 プロ野球選手 |
|
66 リオデジャネイロオリンピック |
|
67 ホークス |
|
68 全日本プロレス |
|
69 プロレス |
|
70 ヴェルディ |
|
71 都市対抗野球 |
|
72 ライオンズ |
|
73 グランプリシリーズ |
|
74 日本プロ野球 |
|
75 アテネオリンピック |
|
76 ヤクルト |
|
77 イーグルス |
|
78 巨人 |
|
79 ワールドシリーズ |
|
80 アーセナル |
|
81 マスターズ |
|
82 ソフトバンク |
|
83 日本ハム |
|
84 クロアチア |
|
85 マリナーズ |
|
86 サッカーリーグ |
|
87 アトランタオリンピック |
|
88 ゴルフ |
|
89 ジャニーズ |
|
90 甲子園 |
|
91 夏の甲子園 |
|
92 陸上競技 |
|
93 ベースボール |
|
94 卓球 |
|
95 プロ |
|
96 南アフリカ |
|
97 レッズ |
|
98 ウルグアイ |
|
99 オールスターゲーム |
|
""" |
|
~~~~ |
|
|
|
|
|
# Model architecture |
|
A 12-layer, 768-hidden-size transformer-based masked language model. |
|
|
|
# Training |
|
The model was trained on [Japanese CC-100](http://data.statmt.org/cc-100/ja.txt.xz) and [Japanese Wikipedia](https://dumps.wikimedia.org/jawiki/) to optimize a masked language modelling objective on 8*V100 GPUs for around 15 days. It reaches ~3.9 perplexity on a dev set sampled from CC-100. |
|
|
|
# Tokenization |
|
The model uses a [sentencepiece](https://github.com/google/sentencepiece)-based tokenizer, the vocabulary was trained on the Japanese Wikipedia using the official sentencepiece training script. |
|
|
|
# Licenese |
|
[The MIT license](https://opensource.org/licenses/MIT) |
|
|