Files changed (1) hide show
  1. tokenizer.py +72 -73
tokenizer.py CHANGED
@@ -1,89 +1,88 @@
1
  import torch
2
  import numpy as np
3
- from transformers import RobertaTokenizer, BatchEncoding
4
  import warnings
5
 
6
 
7
- class JinaTokenizer(RobertaTokenizer):
8
- def __init__(self, *args, **kwargs):
9
- """
10
- JinaTokenizer extends the RobertaTokenizer class to include task_type_ids in
11
- the batch encoding.
12
- The task_type_ids are used to pass instruction information to the model.
13
- A task_type should either be an integer or a sequence of integers with the same
14
- length as the batch size.
15
- """
16
- super().__init__(*args, **kwargs)
 
17
 
18
- def __call__(self, *args, task_type=None, **kwargs):
19
- batch_encoding = super().__call__(*args, **kwargs)
20
- if task_type is not None:
21
- batch_encoding = BatchEncoding(
22
- {
23
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
24
- **batch_encoding,
25
- },
26
- tensor_type=kwargs.get('return_tensors'),
27
- )
28
- return batch_encoding
29
 
30
- def _batch_encode_plus(self, *args, task_type=None, **kwargs):
31
- batch_encoding = super()._batch_encode_plus(*args, **kwargs)
32
- if task_type is not None:
33
- batch_encoding = BatchEncoding(
34
- {
35
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
36
- **batch_encoding,
37
- },
38
- tensor_type=kwargs.get('return_tensors'),
39
- )
40
- return batch_encoding
41
 
42
- def _encode_plus(self, *args, task_type=None, **kwargs):
43
- batch_encoding = super()._encode_plus(*args, **kwargs)
44
- if task_type is not None:
45
- batch_encoding = BatchEncoding(
46
  {
47
- 'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
48
  **batch_encoding,
49
  },
50
- tensor_type=kwargs.get('return_tensors'),
51
  )
52
- return batch_encoding
53
 
54
- @staticmethod
55
- def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
56
 
57
- def apply_task_type(m, x):
58
- x = torch.tensor(x)
59
- assert (
60
- len(x.shape) == 0 or x.shape[0] == m.shape[0]
61
- ), 'The shape of task_type does not match the size of the batch.'
62
- return m * x if len(x.shape) == 0 else m * x[:, None]
63
 
64
- if isinstance(batch_encoding['input_ids'], torch.Tensor):
65
- shape = batch_encoding['input_ids'].shape
66
- return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
67
- else:
68
- try:
69
- shape = torch.tensor(batch_encoding['input_ids']).shape
70
- except:
71
- raise ValueError(
72
- "Unable to create tensor, you should probably "
73
- "activate truncation and/or padding with "
74
- "'padding=True' 'truncation=True' to have batched "
75
- "tensors with the same length."
76
- )
77
- if isinstance(batch_encoding['input_ids'], list):
78
- return (
79
- apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
80
- ).tolist()
81
- elif isinstance(batch_encoding['input_ids'], np.array):
82
- return (
83
- apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
84
- ).numpy()
85
- else:
86
- warnings.warn(
87
- 'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
88
- )
89
  return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import numpy as np
3
+ from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast
4
  import warnings
5
 
6
 
7
+ def get_tokenizer(parent_class):
8
+ class TokenizerClass(parent_class):
9
+ def __init__(self, *args, **kwargs):
10
+ """
11
+ This class dynamically extends a given tokenizer class from the HF
12
+ Transformers library (RobertaTokenizer or RobertaTokenizerFast).
13
+ The task_type_ids are used to pass instruction information to the model.
14
+ A task_type should either be an integer or a sequence of integers with the same
15
+ length as the batch size.
16
+ """
17
+ super().__init__(*args, **kwargs)
18
 
19
+ def __call__(self, *args, task_type=None, **kwargs):
20
+ batch_encoding = super().__call__(*args, **kwargs)
21
+ if task_type is not None:
22
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
23
+ return batch_encoding
 
 
 
 
 
 
24
 
25
+ def _batch_encode_plus(self, *args, task_type=None, **kwargs):
26
+ batch_encoding = super()._batch_encode_plus(*args, **kwargs)
27
+ if task_type is not None:
28
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
29
+ return batch_encoding
30
+
31
+ def _encode_plus(self, *args, task_type=None, **kwargs):
32
+ batch_encoding = super()._encode_plus(*args, **kwargs)
33
+ if task_type is not None:
34
+ batch_encoding = self._add_task_type_ids(batch_encoding, task_type, kwargs.get('return_tensors'))
35
+ return batch_encoding
36
 
37
+ @classmethod
38
+ def _add_task_type_ids(cls, batch_encoding, task_type, tensor_type):
39
+ return BatchEncoding(
 
40
  {
41
+ 'task_type_ids': cls._get_task_type_ids(batch_encoding, task_type),
42
  **batch_encoding,
43
  },
44
+ tensor_type=tensor_type,
45
  )
 
46
 
47
+ @staticmethod
48
+ def _get_task_type_ids(batch_encoding: BatchEncoding, task_type):
49
 
50
+ def apply_task_type(m, x):
51
+ x = torch.tensor(x)
52
+ assert (
53
+ len(x.shape) == 0 or x.shape[0] == m.shape[0]
54
+ ), 'The shape of task_type does not match the size of the batch.'
55
+ return m * x if len(x.shape) == 0 else m * x[:, None]
56
 
57
+ if isinstance(batch_encoding['input_ids'], torch.Tensor):
58
+ shape = batch_encoding['input_ids'].shape
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
60
+ else:
61
+ try:
62
+ shape = torch.tensor(batch_encoding['input_ids']).shape
63
+ except:
64
+ raise ValueError(
65
+ "Unable to create tensor, you should probably "
66
+ "activate truncation and/or padding with "
67
+ "'padding=True' 'truncation=True' to have batched "
68
+ "tensors with the same length."
69
+ )
70
+ if isinstance(batch_encoding['input_ids'], list):
71
+ return (
72
+ apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
73
+ ).tolist()
74
+ elif isinstance(batch_encoding['input_ids'], np.array):
75
+ return (
76
+ apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
77
+ ).numpy()
78
+ else:
79
+ warnings.warn(
80
+ 'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
81
+ )
82
+ return apply_task_type(torch.ones(shape, dtype=torch.long), task_type)
83
+
84
+ return TokenizerClass
85
+
86
+
87
+ JinaTokenizer = get_tokenizer(RobertaTokenizer)
88
+ JinaTokenizerFast = get_tokenizer(RobertaTokenizerFast)