Fix batch input
tokenization_chatglm.py  CHANGED  (+3 -3)
@@ -177,7 +177,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
     vocab_files_names = {"vocab_file": "ice_text.model"}
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
-    model_input_names = ["input_ids"]
+    model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
     def __init__(
         self,
@@ -397,7 +397,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        if
+        if return_attention_mask:
             context_length = required_input.index(bos_token_id)
             attention_mask = np.ones((1, seq_length, seq_length))
             attention_mask = np.tril(attention_mask)
@@ -405,7 +405,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             attention_mask = np.bool_(attention_mask < 0.5)
             encoded_inputs["attention_mask"] = attention_mask
 
-        if
+        if return_attention_mask:
             mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
             mask_position = required_input.index(mask_token)
             context_length = required_input.index(bos_token_id)