from typing import List, Optional, Union

import torch
from tokenizers import Tokenizer
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import (
    BatchEncoding,
    EncodedInput,
    PreTokenizedInput,
    TextInput,
    TruncationStrategy,
)
from transformers.utils import PaddingStrategy, TensorType


def create_tokenizer_custom(file):
    """Load a ``tokenizers.Tokenizer`` from a serialized tokenizer JSON file.

    Args:
        file: Path to the tokenizer JSON file.

    Returns:
        The ``Tokenizer`` deserialized with ``Tokenizer.from_str``.
    """
    # JSON tokenizer files are UTF-8; don't depend on the platform's default
    # text encoding (which differs e.g. on Windows).
    with open(file, 'r', encoding='utf-8') as f:
        return Tokenizer.from_str(f.read())


class iPLMTokenizer(PreTrainedTokenizerFast):
    """Fast tokenizer that can prefix each sequence with structure ids.

    Inputs may be written as ``"<structure>|<sequence>"``. The sequence part is
    tokenized by the underlying fast tokenizer. When ``use_structure`` is set,
    the structure part is encoded character-by-character (``ord`` code points)
    into ``n_queries`` prefix slots. These slots are prepended to
    ``input_ids`` and ``attention_mask``.
    """

    def __init__(self, n_queries, use_structure=True, parallel=False, **kwargs):
        """Build the tokenizer from ``kwargs['tokenizer_file']``.

        Args:
            n_queries: Number of prefix slots reserved for structure ids.
                Forced to 0 when ``use_structure`` is False.
            use_structure: Whether ``__call__`` prepends the structure prefix.
            parallel: Stored as ``self.parallel``; not read by this class.
            **kwargs: Forwarded to ``PreTrainedTokenizerFast``. Must include
                ``tokenizer_file``, the path to the serialized tokenizer JSON.

        Raises:
            ValueError: If ``tokenizer_file`` is not provided.
        """
        tokenizer_file = kwargs.get('tokenizer_file')
        if tokenizer_file is None:
            # Fail with a clear message rather than open(None)'s TypeError.
            raise ValueError("iPLMTokenizer requires a 'tokenizer_file' argument")
        super().__init__(tokenizer_object=create_tokenizer_custom(tokenizer_file), **kwargs)
        self.add_special_tokens({'pad_token': '<|pad|>'})
        self.use_structure = use_structure
        self.n_queries = n_queries if use_structure else 0
        self.parallel = parallel

    def __call__(
        self,
        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
        text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
        text_pair_target: Optional[
            Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
        ] = None,
        add_special_tokens: bool = True,
        padding: Union[bool, str, PaddingStrategy] = False,
        truncation: Union[bool, str, TruncationStrategy] = None,
        max_length: Optional[int] = None,
        stride: int = 0,
        is_split_into_words: bool = False,
        pad_to_multiple_of: Optional[int] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        return_token_type_ids: Optional[bool] = None,
        return_attention_mask: Optional[bool] = None,
        return_overflowing_tokens: bool = False,
        return_special_tokens_mask: bool = False,
        return_offsets_mapping: bool = False,
        return_length: bool = False,
        verbose: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        """Tokenize ``"structure|sequence"`` inputs, prepending structure ids.

        A non-list ``text`` is treated as a single example. For each example
        that contains ``'|'``, the text before the first ``'|'`` is the
        structure string and the text between the first and second ``'|'`` is
        the sequence that gets tokenized. Anything after a second ``'|'`` is
        discarded. Examples without ``'|'`` are tokenized unchanged.

        When ``use_structure`` is True, ``n_queries`` leading columns are
        prepended to each row:

        - ``input_ids``: the structure characters' code points, zero-padded.
        - ``attention_mask``: all True for examples with a structure string,
          all False for examples without one.

        All other arguments are forwarded unchanged to
        ``PreTrainedTokenizerFast.__call__``. Any ``token_type_ids`` are
        removed from the result.

        Returns:
            The ``BatchEncoding`` from the base tokenizer, adjusted as above.

        Raises:
            ValueError: If a structure string is longer than ``n_queries``, or
                the encodings cannot be prefixed (see ``_prepend_prefix``).
        """
        # Forward by keyword: the base __call__ signature has gained parameters
        # over time (recent transformers releases added ``padding_side``), so
        # positional forwarding would silently shift arguments into wrong slots.
        forward_kwargs = dict(
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )

        if text is None:
            # Nothing to split (e.g. only ``text_target`` given): delegate as-is.
            batch = super().__call__(text=None, **forward_kwargs)
        else:
            if not isinstance(text, list):
                text = [text]
            raw_text, input_ids_prefix, attn_mask_prefix = self._split_structure(text)
            batch = super().__call__(text=raw_text, **forward_kwargs)
            if self.use_structure:
                # NOTE(review): ``labels`` produced from ``text_target`` are not
                # prefixed; confirm downstream code expects that.
                if 'attention_mask' in batch:
                    batch['attention_mask'] = self._prepend_prefix(attn_mask_prefix, batch['attention_mask'])
                batch['input_ids'] = self._prepend_prefix(input_ids_prefix, batch['input_ids'])

        if "token_type_ids" in batch:
            del batch["token_type_ids"]
        return batch

    def _split_structure(self, text):
        """Split each example into its sequence text and structure prefix rows.

        Args:
            text: List of examples, each optionally of the form
                ``"structure|sequence"``.

        Returns:
            A tuple ``(raw_text, input_ids_prefix, attn_mask_prefix)``:

            - ``raw_text``: the list of sequence texts to tokenize.
            - ``input_ids_prefix``: an int64 tensor of shape
              ``(len(text), n_queries)``, or None when ``use_structure`` is
              False.
            - ``attn_mask_prefix``: a bool tensor of the same shape, or None
              when ``use_structure`` is False.

        Raises:
            ValueError: If a structure string is longer than ``n_queries``.
        """
        raw_text = []
        input_ids_prefix = attn_mask_prefix = None
        if self.use_structure:
            attn_mask_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.bool)
            input_ids_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.long)

        for i, item in enumerate(text):
            if '|' not in item:
                # No structure: row stays all-zero ids with an all-False mask.
                raw_text.append(item)
                continue
            parts = item.split('|')
            raw_text.append(parts[1])
            if not self.use_structure:
                continue
            structure = parts[0]
            if len(structure) > self.n_queries:
                raise ValueError(
                    f"Structure string of length {len(structure)} exceeds "
                    f"n_queries={self.n_queries}"
                )
            # Each character becomes its code point; unused slots stay 0.
            input_ids_prefix[i, :len(structure)] = torch.tensor(
                [ord(c) for c in structure], dtype=torch.long
            )
            # The whole row is unmasked, including the zero-padded slots.
            attn_mask_prefix[i] = True
        return raw_text, input_ids_prefix, attn_mask_prefix

    @staticmethod
    def _prepend_prefix(prefix, values):
        """Prepend the columns of ``prefix`` to every row of ``values``.

        Args:
            prefix: ``(batch, n_queries)`` tensor built by ``_split_structure``.
            values: Encodings from the base tokenizer. Either a 2-D
                ``torch.Tensor`` (``return_tensors='pt'``) or a list of
                per-example lists (``return_tensors=None``).

        Returns:
            The same container kind as ``values``, with the prefix columns
            prepended to each row.

        Raises:
            ValueError: If ``values`` is any other container type, or its row
                count differs from the prefix's (e.g. with overflowing tokens).
        """
        if not isinstance(values, (torch.Tensor, list)):
            raise ValueError(
                "use_structure=True only supports return_tensors='pt' or None, "
                f"got encodings of type {type(values).__name__}"
            )
        if len(values) != prefix.shape[0]:
            raise ValueError(
                f"Got {len(values)} encoded rows for {prefix.shape[0]} inputs; "
                "structure prefixes cannot be aligned (overflowing tokens are unsupported)"
            )
        if isinstance(values, torch.Tensor):
            # Cast explicitly (e.g. bool mask -> the encoding's int64 mask)
            # instead of relying on torch.cat's implicit type promotion.
            return torch.cat([prefix.to(device=values.device, dtype=values.dtype), values], dim=1)
        return [row_prefix + list(row) for row_prefix, row in zip(prefix.long().tolist(), values)]