"""

## Overview
- Comparison of the BERT and CLUE vocabularies: https://github.com/CLUEbenchmark/CLUECorpus2020#%E8%AF%8D%E8%A1%A8%E4%BB%8B%E7%BB%8D
- Related issue: https://github.com/google-research/bert/issues/396
- The BERT Chinese vocabulary has 21,128 entries (about 20k).
- English letters are all lowercased (is there a variant that keeps case?); see the sketch below.
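
Both points can be checked directly (a minimal sketch; assumes `transformers`
is installed and that vocab.google.txt is the Google bert-base-chinese vocab):

    from transformers import BertTokenizer
    tok = BertTokenizer('vocab.google.txt')
    print(len(tok.vocab))                     # 21128 for bert-base-chinese
    print(tok.basic_tokenizer.do_lower_case)  # True; do_lower_case defaults to True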

Args:
- --vocab-bpe: path to the WordPiece vocab file
- --inputs: input files to scan ('-' reads stdin; this is the default)
- --workers: number of worker processes (default 20)

Output:
- stdout: for each input line containing OOV tokens, "tok1,tok2<TAB>original line"
- stderr: a progress message every 10,000 lines
- file 'oov': every OOV token with its count, sorted by frequency (descending)


python bpe_oov.py \
  --vocab-bpe vocab.google.txt \
  --inputs ../raw/discovery_all \
  --workers 60

# stderr (progress) stays on the screen; stdout (OOV lines) goes to oov_lines
python bpe_oov.py \
  --vocab-bpe vocab.clue_plus.txt \
  --inputs ../raw/discovery_all \
  --workers 60 > oov_lines


python bpe_oov.py \
  --vocab-bpe vocab.clue_plus.txt \
  --inputs ../raw/small/jd.train.raw  \
  --workers 60 > oov_lines
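
# --inputs defaults to '-', so the script can also read stdin
# (hypothetical paths, mirroring the runs above)
python bpe_oov.py \
  --vocab-bpe vocab.clue_plus.txt \
  --workers 60 < ../raw/small/jd.train.raw > oov_lines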




## Whole words

"""

import argparse
import contextlib
import sys
from collections import defaultdict
from multiprocessing import Pool

from transformers import BertTokenizer

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help='path to the WordPiece vocab file (e.g. vocab.google.txt)',
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=['-'],
        help="input files to filter/encode",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()


    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-" else sys.stdin
            for input in args.inputs
        ]

        encoder = MultiprocessingEncoder(args.vocab_bpe)
        pool = Pool(args.workers, initializer=encoder.initializer)
        oov_lines = pool.imap(encoder.get_oov_lines, zip(*inputs), 100)

        oov_count = defaultdict(int)
        for i, oov_line in enumerate(oov_lines, start=1):  # main aggregation loop: tally OOV counts from all workers
            for oov in oov_line:
                oov_count[oov] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)
        pool.close()
        pool.join()
        sorted_oov = sorted(oov_count.items(), key=lambda kv: kv[1], reverse=True)

        with open('oov', 'w', encoding='utf-8') as f_out:
            f_out.write('\n'.join('%s %d' % (k, v) for k, v in sorted_oov) + '\n')

class MultiprocessingEncoder(object):

    def __init__(self, vocab_bpe):
        self.vocab_bpe = vocab_bpe

    def initializer(self):
        # The tokenizer is built here, once per worker process, rather than in
        # __init__: Pool pickles the bound method (and therefore `self`) when
        # dispatching tasks, so the tokenizer lives in a module-level global
        # instead of on the instance to keep each task cheap to serialize.
        global bpe
        bpe = BertTokenizer(self.vocab_bpe)

    def get_oov(self, line):
        global bpe
        oov_tokens = []
        for token in bpe.basic_tokenizer.tokenize(line, never_split=bpe.all_special_tokens):
            for sub_token in bpe.wordpiece_tokenizer.tokenize(token):
                if sub_token == '[UNK]':
                    oov_tokens.append(token)
        if len(oov_tokens) > 0:
            # Note: printing here may be unnecessary, since some of these
            # tokens clearly just need to be added to the vocabulary.
            print(','.join(oov_tokens) + '\t' + line)
        return oov_tokens

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        # encode() returns the ids as strings, so convert back before decoding
        return bpe.decode([int(t) for t in tokens])

    def get_oov_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        all_oov = []
        for line in lines:
            line = line.strip()
            oov_tokens = self.get_oov(line)
            all_oov += oov_tokens
        return all_oov

    def encode_lines(self, lines):
        """
        Encode a batch of lines (not used by main()).
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0:
                # the original checked `self.args.keep_empty`, but this class
                # has no `args` attribute, so empty lines always abort the batch
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]


def test():
    encoder = MultiprocessingEncoder('vocab.clue_plus.txt')
    encoder.initializer()
    line = '蔲驰的,africa❸ 11111111111165000mg❗2⃣piqueddasdasddasdasda,明天25℃,面积120㎡,大约2~3米' \
           '3200×1800分辨率,TAS海关密码锁,PC镜片,采用A+节能能,胶包裏,包裹,薄至6㎜,鬼塚虎,' \
           '多种矿物元素,特别是锶,靚眼,门闩和便携把手,箜篌刺绣,5㎝,锐蝮蛇竞技版鼠标,滑屛式,T桖,sub+dvi,' \
           '呵护牙齦,Baumatic™ ,'
    en = encoder.encode(line)
    print(line)
    print(en)
    print(encoder.decode(en))
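
# A single-process demo of the OOV path, mirroring test() above (a minimal
# sketch: the function and its sample line are illustrative, and it assumes
# vocab.clue_plus.txt exists next to this script).
def test_get_oov():
    encoder = MultiprocessingEncoder('vocab.clue_plus.txt')
    encoder.initializer()
    # likely OOV candidates for a Chinese WordPiece vocab: '蔲', '㎡', '❸'
    print(encoder.get_oov('蔲驰的包裹,面积120㎡,第❸件'))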

if __name__ == "__main__":
    #main()
    test()