# Source: genre-kilt/trie.py (commit 7d15402, "Create trie.py", by nicoladecao; 2.78 kB)
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree at
# https://github.com/facebookresearch/GENRE .
from typing import Dict, List, Optional
class Trie(object):
    """Prefix tree over sequences of integer tokens.

    Nodes are plain nested ``dict`` objects: each key is a token and its
    value is the child node; a node with no children (``{}``) ends a stored
    path. There is no explicit end-of-sequence marker, so a stored sequence
    that is a strict prefix of another stored sequence is not a separate
    leaf.

    A second trie may be chained with :meth:`append`; :meth:`get` then
    continues into it when a prefix leaves this trie, or when the allowed
    next tokens include ``bos_token_id``.
    """

    def __init__(self, sequences: Optional[List[List[int]]] = None):
        """Create a trie, optionally pre-populated.

        Args:
            sequences: token sequences to insert, in order. ``None`` (the
                default) or an empty list produces an empty trie. ``None``
                is used instead of a ``[]`` default so no list object is
                shared between calls.
        """
        self.trie_dict = {}
        # Counts insertions, duplicates included (see __len__).
        self.len = 0
        if sequences:
            for sequence in sequences:
                Trie._add_to_trie(sequence, self.trie_dict)
                self.len += 1
        # Optional chained trie and its trigger token; set by append().
        self.append_trie = None
        self.bos_token_id = None

    def append(self, trie: "Trie", bos_token_id: int):
        """Chain ``trie`` after this one for lookups made by :meth:`get`.

        Args:
            trie: trie consulted when a prefix is not found here, and whose
                root tokens replace ``bos_token_id`` in a lookup result.
            bos_token_id: token that, when it appears among the next tokens
                of a matched prefix, is swapped for ``trie``'s root tokens
                (presumably a beginning-of-sequence id -- confirm with the
                caller).
        """
        self.append_trie = trie
        self.bos_token_id = bos_token_id

    def add(self, sequence: List[int]):
        """Insert ``sequence`` and increment the length counter.

        The counter is incremented even if ``sequence`` was already present.
        """
        Trie._add_to_trie(sequence, self.trie_dict)
        self.len += 1

    def get(self, prefix_sequence: List[int]) -> List[int]:
        """Return the tokens that may follow ``prefix_sequence``.

        Returns an empty list when the prefix is not in this trie and no
        trie has been appended (see :meth:`_get_from_trie`).
        """
        return Trie._get_from_trie(
            prefix_sequence, self.trie_dict, self.append_trie, self.bos_token_id
        )

    @staticmethod
    def load_from_dict(trie_dict: Dict) -> "Trie":
        """Wrap an existing nested-dict trie without copying it.

        The returned trie shares ``trie_dict`` by reference. Its length is
        the number of root-to-leaf paths in ``trie_dict``.
        """
        trie = Trie()
        trie.trie_dict = trie_dict
        trie.len = sum(1 for _ in trie)
        return trie

    @staticmethod
    def _add_to_trie(sequence: List[int], trie_dict: Dict):
        """Insert ``sequence`` into ``trie_dict`` in place, creating nodes as needed."""
        # Iterative walk: avoids the recursion limit on long sequences and
        # the quadratic cost of re-slicing the sequence at every level.
        node = trie_dict
        for token in sequence:
            node = node.setdefault(token, {})

    @staticmethod
    def _get_from_trie(
        prefix_sequence: List[int],
        trie_dict: Dict,
        append_trie: Optional["Trie"] = None,
        bos_token_id: Optional[int] = None,
    ) -> List[int]:
        """Return the allowed next tokens after ``prefix_sequence``.

        Walks ``prefix_sequence`` down ``trie_dict``. When the whole prefix
        matches, the keys of the reached node are returned. If an appended
        trie exists and those keys include ``bos_token_id``, that token is
        removed and the appended trie's root tokens are added instead.

        If a token does not match, the unmatched remainder of the prefix
        (starting at that token) is looked up in ``append_trie``. Without an
        appended trie, ``[]`` is returned.
        """
        node = trie_dict
        for position, token in enumerate(prefix_sequence):
            if token in node:
                node = node[token]
            elif append_trie is not None:
                # Compare against None explicitly: Trie defines __len__, so a
                # truthiness test would ignore an appended trie whose length
                # counter is 0.
                return append_trie.get(prefix_sequence[position:])
            else:
                return []

        output = list(node.keys())
        if append_trie is not None and bos_token_id in output:
            output.remove(bos_token_id)
            output += list(append_trie.trie_dict.keys())
        return output

    def __iter__(self):
        """Iterate over every root-to-leaf path as a list of tokens.

        Paths are produced depth-first in dict insertion order. An empty trie
        yields nothing.
        """
        root = self.trie_dict

        def _traverse():
            if not root:
                return
            # Explicit stack instead of nested generators, so deep paths
            # cannot exceed the interpreter recursion limit.
            stack = [([], root)]
            while stack:
                prefix, node = stack.pop()
                if not node:
                    yield prefix
                    continue
                # Push children in reverse so they are popped in insertion order.
                for token in reversed(list(node)):
                    stack.append((prefix + [token], node[token]))

        return _traverse()

    def __len__(self):
        """Return the stored length counter.

        For tries built by the constructor or :meth:`add`, this is the number
        of insertions (duplicates included). For :meth:`load_from_dict`, it is
        the number of root-to-leaf paths.
        """
        return self.len

    def __getitem__(self, value):
        """Alias for :meth:`get`: ``trie[prefix]`` returns the allowed next tokens."""
        return self.get(value)