import logging
import warnings
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

log = logging.getLogger(__name__)

_HF_IGNORE_INDEX = -100

TokenizedExample = Dict[str, List[Dict[str, List[int]]]]
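
# For reference, a TokenizedExample stores one entry per conversation turn
# under a 'turns' key; the token ids below are made-up placeholders:
#
#   {
#       'turns': [
#           {'input_ids': [1, 2, 3], 'labels': [4, 5]},
#           {'input_ids': [6, 7], 'labels': [8, 9, 10]},
#       ],
#   }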


def ensure_list(x: Union[List, torch.Tensor]) -> List:
    """Return ``x`` as a list, flattening it first if it is a tensor."""
    if isinstance(x, torch.Tensor):
        x = list(x.flatten())
    assert isinstance(x, list)
    return x


def validate_target_settings(
    target_prompts: str,
    target_responses: str,
    decoder_only_format: bool,
):
    """Raises an error if the target settings are invalid."""
    if not decoder_only_format and (target_prompts != 'none' or
                                    target_responses != 'last'):
        raise ValueError(
            'When using encoder_decoder format, you must use target_prompts="none" '
            'and target_responses="last".',
        )
    if target_responses not in {'all', 'last'}:
        raise ValueError(
            f'target_responses must be either "last" or "all" but target_responses={target_responses!r}',
        )
    if target_prompts.startswith('length>='):
        cutoff = target_prompts[8:]
        if not cutoff.isdigit():
            raise ValueError(
                f'target_prompts starts with "length>=" but the rest of the string is not digits (target_prompts={target_prompts!r}). '
                'To use this configuration option, set target_prompts="length>=XX" where "XX" is a positive integer indicating '
                'the length cutoff. Prompts of at least XX tokens in length will be treated as targets.',
            )
        cutoff = int(cutoff)
        if cutoff <= 0:
            raise ValueError(
                f'The target_prompts length cutoff must be a positive integer, but got cutoff={cutoff!r}.',
            )
    elif target_prompts not in {'all', 'none'}:
        raise ValueError(
            f'target_prompts must either be "all", "none", or "length>=XX" where "XX" is a positive integer, but target_prompts={target_prompts!r}',
        )
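
# A few illustrative calls (for reference; they exercise the checks above):
#
#   validate_target_settings('none', 'last', decoder_only_format=True)        # OK
#   validate_target_settings('length>=512', 'all', decoder_only_format=True)  # OK
#   validate_target_settings('all', 'last', decoder_only_format=False)        # ValueError:
#       encoder_decoder format requires target_prompts="none", target_responses="last"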


def _sequence_to_labels_all(sequence: list[int], is_last_turn: bool,
                            cutoff: Optional[int] = None) -> list[int]:
    """Use every token in the sequence as a training target."""
    del is_last_turn, cutoff  # unused
    return sequence


def _sequence_to_labels_none(sequence: list[int], is_last_turn: bool,
                             cutoff: Optional[int] = None) -> list[int]:
    """Mask every token in the sequence (no loss is generated)."""
    del is_last_turn, cutoff  # unused
    return [_HF_IGNORE_INDEX] * len(sequence)


def _sequence_to_labels_last(sequence: list[int], is_last_turn: bool,
                             cutoff: Optional[int] = None) -> list[int]:
    """Use the sequence as a training target only if it is the last turn."""
    del cutoff  # unused
    if is_last_turn:
        return sequence
    else:
        return [_HF_IGNORE_INDEX] * len(sequence)


def _sequence_to_labels_cutoff(sequence: list[int], is_last_turn: bool,
                               cutoff: Optional[int] = None) -> list[int]:
    """Use the sequence as a training target only if it has at least ``cutoff`` tokens."""
    del is_last_turn  # unused
    if cutoff is None:
        raise ValueError('input ``cutoff`` must be provided')
    if len(sequence) >= cutoff:
        return sequence
    else:
        return [_HF_IGNORE_INDEX] * len(sequence)


_TARGET_POLICY_LOOKUP = {
    'all': _sequence_to_labels_all,
    'none': _sequence_to_labels_none,
    'last': _sequence_to_labels_last,
    'length': _sequence_to_labels_cutoff,
}
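
# For reference, how the lookup's policies treat a short, made-up sequence:
#
#   _TARGET_POLICY_LOOKUP['last']([5, 6], is_last_turn=False)    # -> [-100, -100]
#   _TARGET_POLICY_LOOKUP['last']([5, 6], is_last_turn=True)     # -> [5, 6]
#   _TARGET_POLICY_LOOKUP['length']([5, 6, 7], False, cutoff=2)  # -> [5, 6, 7]
#   _TARGET_POLICY_LOOKUP['length']([5], False, cutoff=2)        # -> [-100]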


def stitch_turns_decoder_only(
    example_turns: list[dict[str, list[int]]],
    target_prompts: str,
    target_responses: str,
    eos_token_id: Optional[int] = None,
    validate: bool = False,
) -> tuple[list[int], list[int]]:
    """Stitch the turns of an example into decoder-only ``input_ids``/``labels``."""
    target_prompts = target_prompts.lower()
    target_responses = target_responses.lower()
    if validate:
        validate_target_settings(
            target_prompts,
            target_responses,
            decoder_only_format=True,
        )
    # Select the policy functions that convert each prompt/response sequence
    # into label tokens (kept as-is or masked with _HF_IGNORE_INDEX).
    if target_prompts.startswith('length'):
        prompt_cutoff = int(target_prompts.split('>=')[-1])
        prompt_to_target = _TARGET_POLICY_LOOKUP['length']
    else:
        prompt_cutoff = None
        prompt_to_target = _TARGET_POLICY_LOOKUP[target_prompts]
    response_to_target = _TARGET_POLICY_LOOKUP[target_responses]

    input_ids = []
    labels = []
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == len(example_turns)
        context = ensure_list(turn['input_ids'])
        target = ensure_list(turn['labels'])
        # Append the EOS token to the final response if it is not already there.
        if is_last_turn and eos_token_id is not None:
            if target[-1] != eos_token_id:
                target = target + [eos_token_id]
        input_ids += context
        input_ids += target
        labels += prompt_to_target(context, is_last_turn, prompt_cutoff)
        labels += response_to_target(target, is_last_turn)
    if len(input_ids) != len(labels):
        raise ValueError(
            'input_ids and labels should be the same length, '
            f'len(input_ids)={len(input_ids)!r}, len(labels)={len(labels)!r}',
        )
    return input_ids, labels
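
# A worked example with two made-up turns, using the collator's default
# policies (target_prompts='none', target_responses='last'):
#
#   turns = [
#       {'input_ids': [1, 2], 'labels': [3]},
#       {'input_ids': [4], 'labels': [5]},
#   ]
#   stitch_turns_decoder_only(turns, target_prompts='none', target_responses='last')
#   # -> ([1, 2, 3, 4, 5], [-100, -100, -100, -100, 5])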


def stitch_turns_encoder_decoder(
    example_turns: list[dict[str, list[int]]],
    eos_token_id: Optional[int] = None,
) -> tuple[list[int], list[int]]:
    """Stitch the turns of an example into encoder-decoder ``context``/``target``."""
    context = []
    target = None
    for idx, turn in enumerate(example_turns):
        is_last_turn = idx + 1 == len(example_turns)
        turn_context = ensure_list(turn['input_ids'])
        turn_target = ensure_list(turn['labels'])
        context += turn_context
        if is_last_turn:
            # Only the final response is the target; add EOS if it is missing.
            if eos_token_id is not None and turn_target[-1] != eos_token_id:
                turn_target = turn_target + [eos_token_id]
            target = turn_target
        else:
            # Non-terminal responses are folded into the encoder context.
            context += turn_target
    if target is None:
        raise ValueError('target is still None but should be list[int]')
    return context, target
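
# With the same two made-up turns, everything before the final response is
# encoder context and the final response is the decoder target:
#
#   stitch_turns_encoder_decoder(turns)
#   # -> ([1, 2, 3, 4], [5])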


class Seq2SeqFinetuningCollator:
    """A general-purpose collator for sequence-to-sequence training/evaluation.

    Args:
        tokenizer: A HuggingFace tokenizer. Must have a pad_token set.
        max_seq_len (int): The maximum sequence length of the combined
            context/target sequence (decoder-only format) or of each of the
            context sequence and target sequence (encoder-decoder format).
        decoder_only_format (bool): Whether to format the batches for a
            decoder-only model (if True) or an encoder-decoder model (if False).
        target_responses (str): For multi-turn examples, this controls which
            responses are treated as training targets (i.e. generate loss).
            Options are:
                "last": (Default) Only the final response is used as the
                    training target; non-terminal responses are only part of
                    the context.
                "all": All of the responses are used as training targets.
        target_prompts (str): This controls which prompts are treated as
            training targets (i.e. generate loss).
            Options are:
                "none": (Default) Prompts are never used as training targets.
                "all": Prompts are always used as training targets.
                "length>=XX": Prompt sequences are used as training targets
                    when they have length of at least XX tokens. For instance,
                    setting "length>=512" instructs the collator to use a
                    prompt sequence as a training target when it is at least
                    512 tokens long.
        allow_pad_trimming (bool, optional): Whether to allow the collator
            to trim padding, which may result in smaller but inconsistent
            batch sizes. Default: ``False`` ensures that all sequences are
            max_seq_len.
        batch_metadata (dict, optional): A dictionary of metadata which will
            be added to the batch.
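
    Example (an illustrative sketch; ``tokenizer`` stands for any HuggingFace
    tokenizer with a pad token set, and ``examples`` for a list of
    ``TokenizedExample`` dicts as described at the top of this module)::

        collator = Seq2SeqFinetuningCollator(
            tokenizer=tokenizer,
            max_seq_len=2048,
            decoder_only_format=True,
        )
        batch = collator(examples)  # a dict of padded, batched tensors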
    """

    def __init__(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        max_seq_len: int,
        decoder_only_format: bool,
        target_responses: str = 'last',
        target_prompts: str = 'none',
        allow_pad_trimming: bool = False,
        batch_metadata: Optional[Dict[str, Any]] = None,
    ):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.decoder_only_format = decoder_only_format
        self.target_responses = target_responses.lower()
        self.target_prompts = target_prompts.lower()
        self.batch_metadata = batch_metadata or {}

        # Pad trimming is always skipped on at least the first batch.
        self._allow_pad_trimming = allow_pad_trimming
        self._seen_first_batch = False

        illegal_keys = [
            'input_ids',
            'labels',
            'attention_mask',
            'decoder_input_ids',
            'decoder_attention_mask',
        ]
        found_keys = []
        for illegal_key in illegal_keys:
            if illegal_key in self.batch_metadata:
                found_keys.append(illegal_key)
        if found_keys:
            raise ValueError(
                f"The following keys are in batch_metadata but are not allowed: {', '.join(found_keys)}.\n"
                'You cannot use keys that are used directly by the models. The prohibited keys are:\n'
                f"{', '.join(illegal_keys)}",
            )
        if max_seq_len % 8 != 0:
            log.warning(
                'For performance, a max_seq_len as a multiple of 8 is recommended.',
            )
        if self.tokenizer.pad_token_id is None:
            raise ValueError(
                f'{self.__class__.__name__} requires that the tokenizer has the pad token set, but it is None',
            )

        validate_target_settings(
            self.target_prompts,
            self.target_responses,
            self.decoder_only_format,
        )
        # Resolve the label policies once so __call__ does not have to.
        if self.target_prompts.startswith('length'):
            self.prompt_cutoff = int(self.target_prompts.split('>=')[-1])
            self.prompt_to_target = _TARGET_POLICY_LOOKUP['length']
        else:
            self.prompt_cutoff = None
            self.prompt_to_target = _TARGET_POLICY_LOOKUP[self.target_prompts]
        self.response_to_target = _TARGET_POLICY_LOOKUP[self.target_responses]

        self._warned_truncated = False
        self._warned_context = False
        self._warned_target = False

    def __call__(
        self,
        examples: List[TokenizedExample],
    ) -> Dict[str, torch.Tensor]:
        for check_key in ['input_ids', 'labels']:
            if check_key not in examples[0]['turns'][0]:
                raise KeyError(
                    f'Examples returned by dataset do not include required key: {check_key}',
                )

        if self.decoder_only_format:
            batch = self._process_and_batch_decoder_only(examples)
        else:
            batch = self._process_and_batch_encoder_decoder(examples)

        # Add any batch_metadata, repeated for each example in the batch.
        batch_size = batch['input_ids'].shape[0]
        batch.update({
            k: torch.tensor([v] * batch_size)
            for k, v in self.batch_metadata.items()
        })
        return batch

    def _process_and_batch_decoder_only(
        self,
        examples: List[TokenizedExample],
    ) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            input_ids, labels = stitch_turns_decoder_only(
                example_turns=example['turns'],
                target_prompts=self.target_prompts,
                target_responses=self.target_responses,
                eos_token_id=self.tokenizer.eos_token_id,
            )

            orig_size = len(input_ids)
            # Truncate if necessary to stay within max_seq_len.
            if orig_size > self.max_seq_len:
                input_ids = input_ids[:self.max_seq_len]
                labels = labels[:self.max_seq_len]

                # Error if truncation removed every loss-generating token.
                if len([l for l in labels if l != _HF_IGNORE_INDEX]) == 0:
                    raise ValueError(
                        f'Truncating to max_seq_len={self.max_seq_len} has removed all loss-generating tokens. '
                        f'Pre-truncation sequence length was {orig_size}. '
                        'This sample should have been filtered out before reaching the collator. If using '
                        'pre-tokenized streaming data, this may have resulted from using different '
                        '``target_prompts``, ``target_responses``, or ``max_seq_len`` '
                        'settings when preparing the streaming dataset than what are currently being used.',
                    )

                # Warn the first time truncation happens.
                if not self._warned_truncated:
                    warnings.warn(
                        f'Truncating sequence of length={orig_size} to fit max_seq_len={self.max_seq_len}. '
                        'If truncation is a problem, consider increasing max_seq_len.',
                    )
                    self._warned_truncated = True

            attention_mask = [1] * len(input_ids)

            # The tokenizer only pads input_ids and attention_mask, so the
            # labels are padded (with the ignore index) here.
            n_total = len(input_ids)
            i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - n_total)
            if self.tokenizer.padding_side == 'left':
                labels = i_pad + labels
            else:
                labels = labels + i_pad

            processed_example = {
                'input_ids': input_ids,
                'labels': labels,
                'attention_mask': attention_mask,
            }
            processed_examples.append(processed_example)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # Real tokens get sequence_id 0; padded positions get -1.
        batch['sequence_id'] = batch['attention_mask'] - 1

        # Pad trimming is always skipped for the first batch.
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # Trim the batch to the nearest multiple-of-8 length that still covers
        # the longest non-padded sequence.
        multiple_of = 8
        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k, v in batch.items():
            if len(v.shape) < 2:
                continue
            if self.tokenizer.padding_side == 'left':
                batch[k] = v[:, -keep_tokens:].contiguous()
            else:
                batch[k] = v[:, :keep_tokens].contiguous()

        return batch

    def _process_and_batch_encoder_decoder(
        self,
        examples: List[TokenizedExample],
    ) -> Dict[str, torch.Tensor]:
        processed_examples = []
        for example in examples:
            context, target = stitch_turns_encoder_decoder(
                example_turns=example['turns'],
                eos_token_id=self.tokenizer.eos_token_id,
            )

            # The tokenizer only pads input_ids and attention_mask, so the
            # target is padded (with the ignore index) or truncated here.
            if len(target) < self.max_seq_len:
                i_pad = [_HF_IGNORE_INDEX] * (self.max_seq_len - len(target))
                target = target + i_pad
            else:
                if not self._warned_target:
                    warnings.warn(
                        f'Truncating TARGET sequence of length={len(target)} '
                        f'to max_seq_len={self.max_seq_len}. If truncation is '
                        'a problem, consider increasing max_seq_len.',
                    )
                    self._warned_target = True
                target = target[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]

            # Truncate the context if necessary, ending it with an EOS token.
            if len(context) > self.max_seq_len:
                if not self._warned_context:
                    warnings.warn(
                        f'Truncating CONTEXT sequence of length={len(context)} '
                        f'to max_seq_len={self.max_seq_len}. If truncation is '
                        'a problem, consider increasing max_seq_len.',
                    )
                    self._warned_context = True
                context = context[:self.max_seq_len - 1] + [self.tokenizer.eos_token_id]

            processed_example = {
                'input_ids': context,
                'labels': target,
                'attention_mask': [1] * len(context),
            }
            processed_examples.append(processed_example)

        batch = self.tokenizer.pad(
            processed_examples,
            padding='max_length',
            max_length=self.max_seq_len,
            return_tensors='pt',
        )

        # Build decoder inputs by shifting the labels right one position,
        # starting from the pad token, then replace any ignore-index values
        # with the pad token id.
        batch['decoder_input_ids'] = torch.cat([
            torch.full((len(processed_examples), 1), self.tokenizer.pad_token_id),
            batch['labels'][:, :-1],
        ], dim=1)
        batch['decoder_input_ids'].masked_fill_(
            batch['decoder_input_ids'] == _HF_IGNORE_INDEX,
            self.tokenizer.pad_token_id,
        )
        batch['decoder_attention_mask'] = torch.not_equal(
            batch['labels'],
            _HF_IGNORE_INDEX,
        )

        # Pad trimming is always skipped for the first batch.
        if not (self._allow_pad_trimming and self._seen_first_batch):
            self._seen_first_batch = True
            return batch
        self._seen_first_batch = True

        # Trim the encoder-side and decoder-side tensors separately, each to
        # the nearest multiple-of-8 length covering its longest sequence.
        multiple_of = 8

        n_non_padding = batch['attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['input_ids', 'attention_mask']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        n_non_padding = batch['decoder_attention_mask'].sum(dim=1).max()
        keep_tokens = int(multiple_of * torch.ceil(n_non_padding / multiple_of))
        for k in ['decoder_input_ids', 'decoder_attention_mask', 'labels']:
            batch[k] = batch[k][:, :keep_tokens].contiguous()

        return batch