|
import torch |
|
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
class CTCLabelConverter(object):
    """Convert between text labels and CTC index sequences.

    Index 0 is reserved for the CTC blank token ``'[CTCblank]'``; the i-th
    character of ``character`` is assigned index ``i + 1``.
    """

    def __init__(self, character):
        """Build the character <-> index lookup tables.

        Args:
            character: iterable of characters (typically a string) forming the
                alphabet. If a character repeats, its last position wins in
                ``self.dict``.
        """
        dict_character = list(character)

        # Start at 1 so that index 0 stays free for the CTC blank.
        self.dict = {char: i + 1 for i, char in enumerate(dict_character)}

        self.character = ['[CTCblank]'] + dict_character

    def encode(self, text, batch_max_length=25):
        """Convert text labels into a zero-padded index matrix for CTCLoss.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by
                default. A label longer than this does not fit in its row and
                the tensor slice assignment fails.

        Returns:
            (batch_text, length):
                batch_text: int64 tensor [batch_size, batch_max_length] of
                    character indices, padded with 0 (the blank index).
                length: int32 tensor [batch_size] with each label's length.
            Both tensors are moved to the module-level ``device``.

        Raises:
            KeyError: if a label contains a character outside the alphabet;
                the message names the offending character and label.
        """
        length = [len(s) for s in text]

        # Pad with 0 (the blank index); the per-label lengths returned
        # alongside tell the loss how much of each row is real.
        batch_text = torch.zeros(len(text), batch_max_length, dtype=torch.long)
        for row, label in enumerate(text):
            try:
                indices = [self.dict[char] for char in label]
            except KeyError as err:
                raise KeyError(
                    f'character {err.args[0]!r} in label {label!r} '
                    f'is not in the character set'
                ) from err
            batch_text[row, :len(indices)] = torch.tensor(indices, dtype=torch.long)
        return (batch_text.to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        """Convert CTC index rows back into text labels.

        For each row, consecutive repeated indices are collapsed to one and
        blank (0) indices are dropped. A repeat separated by a blank is kept.

        Args:
            text_index: 2-D index matrix (tensor or array with ``tolist``);
                row ``k`` holds the predicted indices for sample ``k``.
            length: number of leading positions of each row to decode.

        Returns:
            list of decoded strings, one per entry of ``length``.
        """
        texts = []
        for index, l in enumerate(length):
            # Convert the row to plain ints once rather than comparing
            # 0-d tensors element by element.
            t = text_index[index, :].tolist()

            char_list = []
            for i in range(l):
                if t[i] != 0 and not (i > 0 and t[i - 1] == t[i]):
                    char_list.append(self.character[t[i]])
            texts.append(''.join(char_list))
        return texts
|
|
|
|
|
class CTCLabelConverterForBaiduWarpctc(object):
    """Map text labels to and from the flat CTC index layout used with baidu warpctc.

    Index 0 denotes the CTC blank (``'[CTCblank]'``). The character at
    position ``i`` of ``character`` gets index ``i + 1``. Encoded labels are
    concatenated into a single 1-D tensor and are created on the CPU.
    """

    def __init__(self, character):
        alphabet = list(character)
        # Offset by one: index 0 is reserved for the blank.
        self.dict = {symbol: pos + 1 for pos, symbol in enumerate(alphabet)}
        self.character = ['[CTCblank]'] + alphabet

    def encode(self, text, batch_max_length=25):
        """Encode a batch of labels as one concatenated index sequence.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: accepted for interface compatibility; unused here.

        Returns:
            (indices, lengths):
                indices: int32 tensor holding every label's indices back to
                    back. [sum(text_lengths)]
                lengths: int32 tensor of per-label lengths. [batch_size]
        """
        lengths = [len(label) for label in text]
        joined = ''.join(text)
        indices = [self.dict[symbol] for symbol in joined]
        return (torch.IntTensor(indices), torch.IntTensor(lengths))

    def decode(self, text_index, length):
        """Decode a concatenated index sequence back into strings.

        ``text_index`` is consumed in consecutive segments whose sizes come
        from ``length``. Within each segment, consecutive repeated indices
        collapse to one and blank (0) indices are dropped.
        """
        decoded = []
        start = 0
        for seg_len in length:
            segment = text_index[start:start + seg_len]
            symbols = []
            for pos in range(seg_len):
                current = segment[pos]
                is_blank = current == 0
                repeats_previous = pos > 0 and segment[pos - 1] == current
                if is_blank or repeats_previous:
                    continue
                symbols.append(self.character[current])
            decoded.append(''.join(symbols))
            start += seg_len
        return decoded
|
|
|
|
|
class AttnLabelConverter(object):
    """Convert between text labels and the index layout of an attention decoder.

    Index 0 is ``'[GO]'`` and index 1 is ``'[s]'``. The characters of
    ``character`` follow from index 2 on.
    """

    def __init__(self, character):
        # '[GO]' occupies the first slot of every encoded row and is also the
        # padding value; '[s]' is appended after each label.
        self.character = ['[GO]', '[s]'] + list(character)
        self.dict = {token: position for position, token in enumerate(self.character)}

    def encode(self, text, batch_max_length=25):
        """Convert text labels into decoder input rows.

        Args:
            text: text labels of each image. [batch_size]
            batch_max_length: max length of text label in the batch. 25 by default.

        Returns:
            (batch_text, length):
                batch_text: int64 tensor [batch_size x (batch_max_length + 2)].
                    Column 0 is '[GO]', the label follows, then '[s]'. The rest
                    of the row is padded with index 0 ('[GO]').
                length: int32 tensor [batch_size]; each label's length
                    including its '[s]' token.
            Both tensors are moved to the module-level ``device``.
        """
        lengths = [len(label) + 1 for label in text]

        # One leading [GO] slot plus room for the label and its [s] terminator.
        width = batch_max_length + 2
        batch_text = torch.LongTensor(len(text), width).fill_(0)
        for row, label in enumerate(text):
            tokens = list(label)
            tokens.append('[s]')
            ids = [self.dict[token] for token in tokens]
            batch_text[row][1:1 + len(ids)] = torch.LongTensor(ids)
        return (batch_text.to(device), torch.IntTensor(lengths).to(device))

    def decode(self, text_index, length):
        """Map every index of each row back to its token and join them.

        The whole row is converted, including '[s]' and any padding after it.
        ``length`` only determines how many rows are decoded.
        """
        return [
            ''.join([self.character[i] for i in text_index[row, :]])
            for row, _ in enumerate(length)
        ]
|
|
|
|
|
class Averager(object):
    """Running mean over every element of the tensors passed to ``add``.

    Typically used to average loss values across iterations.
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Fold all elements of tensor ``v`` into the running total.

        Reads ``v.data``, so no autograd history is accumulated.
        """
        data = v.data
        elements, total = data.numel(), data.sum()
        self.n_count += elements
        self.sum += total

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean of all added elements, or 0 if nothing was added."""
        if self.n_count == 0:
            return 0
        return self.sum / float(self.n_count)
|
|