from typing import List, Optional, Union

import torch
from tokenizers import Tokenizer
from transformers import PreTrainedTokenizerFast
from transformers.tokenization_utils_base import BatchEncoding, PreTokenizedInput, TextInput, TruncationStrategy
from transformers.utils import PaddingStrategy, TensorType

def create_tokenizer_custom(file):
    # Load a serialized `tokenizers` tokenizer from its JSON string form.
    with open(file, 'r') as f:
        return Tokenizer.from_str(f.read())


class iPLMTokenizer(PreTrainedTokenizerFast):
    def __init__(self, n_queries, use_structure=True, parallel=False, **kwargs):
        super().__init__(tokenizer_object=create_tokenizer_custom(kwargs.get('tokenizer_file')), **kwargs)
        self.add_special_tokens({'pad_token': '<|pad|>'})
        self.use_structure = use_structure
        # Reserve n_queries prefix positions for structure tokens; zero when structure is disabled.
        self.n_queries = n_queries if use_structure else 0
        self.parallel = parallel

    def __call__(
            self,
            text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
            text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
            text_target: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
            text_pair_target: Optional[
                Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]
            ] = None,
            add_special_tokens: bool = True,
            padding: Union[bool, str, PaddingStrategy] = False,
            truncation: Union[bool, str, TruncationStrategy] = None,
            max_length: Optional[int] = None,
            stride: int = 0,
            is_split_into_words: bool = False,
            pad_to_multiple_of: Optional[int] = None,
            return_tensors: Optional[Union[str, TensorType]] = None,
            return_token_type_ids: Optional[bool] = None,
            return_attention_mask: Optional[bool] = None,
            return_overflowing_tokens: bool = False,
            return_special_tokens_mask: bool = False,
            return_offsets_mapping: bool = False,
            return_length: bool = False,
            verbose: bool = True,
            **kwargs,
        ) -> BatchEncoding:
        
        raw_text = []

        if not isinstance(text, list):
            text = [text]

        if self.use_structure:
            # Prefix slots for the structure tokens, one row per input sequence.
            # torch.long matches the dtypes Hugging Face uses for input_ids and
            # attention_mask, so no implicit promotion is needed at concat time.
            attn_mask_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.long)
            input_ids_prefix = torch.zeros((len(text), self.n_queries), dtype=torch.long)

        for i in range(len(text)):
            if '|' in text[i]:
                # Inputs carrying structure are formatted as '<structure_id>|<sequence>'.
                res = text[i].split('|')
                raw_text.append(res[1])

                if self.use_structure:
                    # Convert the structure id to its ASCII codes and left-align it
                    # in the zero-padded prefix (it must fit within n_queries slots).
                    structure_id = torch.tensor([ord(c) for c in res[0]])
                    input_ids_prefix[i, :len(structure_id)] = structure_id

                    # Attend to the full query prefix for sequences with structure.
                    attn_mask_prefix[i] = 1
            else:
                raw_text.append(text[i])

        batch = super().__call__(
            raw_text,
            text_pair=text_pair,
            text_target=text_target,
            text_pair_target=text_pair_target,
            add_special_tokens=add_special_tokens,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            stride=stride,
            is_split_into_words=is_split_into_words,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors=return_tensors,
            return_token_type_ids=return_token_type_ids,
            return_attention_mask=return_attention_mask,
            return_overflowing_tokens=return_overflowing_tokens,
            return_special_tokens_mask=return_special_tokens_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **kwargs,
        )
        
        if self.use_structure:
            # Prepend the structure prefix to the tokenized batch. This assumes
            # return_tensors='pt', so the batch values are torch tensors.
            batch['attention_mask'] = torch.cat([attn_mask_prefix, batch['attention_mask']], dim=1)
            batch['input_ids'] = torch.cat([input_ids_prefix, batch['input_ids']], dim=1)

        # token_type_ids would no longer line up with the prefixed input_ids, so drop them.
        if "token_type_ids" in batch:
            del batch["token_type_ids"]

        return batch
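

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original file). The tokenizer file
# path, the n_queries value, and the sequences below are hypothetical
# placeholders; inputs carrying structure follow the '<structure_id>|<sequence>'
# convention the tokenizer expects.
if __name__ == '__main__':
    tokenizer = iPLMTokenizer(
        n_queries=32,                        # assumed query-token budget
        use_structure=True,
        tokenizer_file='tokenizer.json',     # hypothetical path to a serialized tokenizer
    )
    batch = tokenizer(
        ['1abc_A|MKTAYIAKQR', 'MKTAYIAKQR'],  # one input with structure, one without
        padding=True,
        return_tensors='pt',                  # required: the prefixes are torch tensors
    )
    # Both tensors gain n_queries prefix columns when use_structure=True.
    print(batch['input_ids'].shape)
    print(batch['attention_mask'].shape)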