|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List |
|
|
|
|
|
class Trie(object): |
|
def __init__(self, sequences: List[List[int]] = []): |
|
self.trie_dict = {} |
|
self.len = 0 |
|
if sequences: |
|
for sequence in sequences: |
|
Trie._add_to_trie(sequence, self.trie_dict) |
|
self.len += 1 |
|
|
|
self.append_trie = None |
|
self.bos_token_id = None |
|
|
|
def append(self, trie, bos_token_id): |
|
self.append_trie = trie |
|
self.bos_token_id = bos_token_id |
|
|
|
def add(self, sequence: List[int]): |
|
Trie._add_to_trie(sequence, self.trie_dict) |
|
self.len += 1 |
|
|
|
def get(self, prefix_sequence: List[int]): |
|
return Trie._get_from_trie( |
|
prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id |
|
) |
|
|
|
@staticmethod |
|
def load_from_dict(trie_dict): |
|
trie = Trie() |
|
trie.trie_dict = trie_dict |
|
trie.len = sum(1 for _ in trie) |
|
return trie |
|
|
|
@staticmethod |
|
def _add_to_trie(sequence: List[int], trie_dict: Dict): |
|
if sequence: |
|
if sequence[0] not in trie_dict: |
|
trie_dict[sequence[0]] = {} |
|
Trie._add_to_trie(sequence[1:], trie_dict[sequence[0]]) |
|
|
|
@staticmethod |
|
def _get_from_trie( |
|
prefix_sequence: List[int], |
|
trie_dict: Dict, |
|
append_trie=None, |
|
bos_token_id: int = None, |
|
): |
|
if len(prefix_sequence) == 0: |
|
output = list(trie_dict.keys()) |
|
if append_trie and bos_token_id in output: |
|
output.remove(bos_token_id) |
|
output += list(append_trie.trie_dict.keys()) |
|
return output |
|
elif prefix_sequence[0] in trie_dict: |
|
return Trie._get_from_trie( |
|
prefix_sequence[1:], |
|
trie_dict[prefix_sequence[0]], |
|
append_trie, |
|
bos_token_id, |
|
) |
|
else: |
|
if append_trie: |
|
return append_trie.get(prefix_sequence) |
|
else: |
|
return [] |
|
|
|
def __iter__(self): |
|
def _traverse(prefix_sequence, trie_dict): |
|
if trie_dict: |
|
for next_token in trie_dict: |
|
yield from _traverse( |
|
prefix_sequence + [next_token], trie_dict[next_token] |
|
) |
|
else: |
|
yield prefix_sequence |
|
|
|
return _traverse([], self.trie_dict) |
|
|
|
def __len__(self): |
|
return self.len |
|
|
|
def __getitem__(self, value): |
|
return self.get(value) |