# fffiloni's picture
# Migrated from GitHub
# d59f323 verified
# (raw / history / blame — 5.13 kB)
import copy
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX
def video_lisa_encode_fn(
    example,
    tokenizer,
    max_length,
    input_ids_with_output=True,
    **kwargs
):
    """Tokenize one conversation sample into ``input_ids`` and ``labels``.

    Three layouts of ``example['conversation']`` are supported:

    1. Incremental pretraining — a single turn whose ``'input'`` is ``''``
       and whose ``'output'`` carries the raw text.
    2. Single-turn conversation — one ``{'input': ..., 'output': ...}`` dict.
    3. Multi-turn conversation — a list of such dicts; in this case
       ``input_ids_with_output`` must be True.

    Prompt tokens (and BOS/SEP) are masked with ``IGNORE_INDEX`` in
    ``labels``; output tokens (and EOS) carry loss unless the turn sets
    ``'output_with_loss': False``.  The final sequences are truncated to
    ``max_length``.

    Returns:
        dict with ``'input_ids'`` and ``'labels'`` (equal-length int lists).
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    turns = example['conversation']
    if len(turns) > 1:
        # Multi-turn data is only meaningful when outputs are encoded too.
        assert input_ids_with_output
    input_ids = []
    labels = []
    needs_bos = True
    for turn in turns:
        prompt_ids = tokenizer.encode(turn['input'], add_special_tokens=False)
        if needs_bos:
            input_ids.extend(bos_token_id)
            labels.extend([IGNORE_INDEX] * len(bos_token_id))
        # Prompt tokens never carry loss.
        input_ids.extend(prompt_ids)
        labels.extend([IGNORE_INDEX] * len(prompt_ids))
        if not input_ids_with_output:
            continue
        # Append the answer; it carries loss unless explicitly disabled.
        with_loss = turn.get('output_with_loss', True)
        answer_ids = tokenizer.encode(turn['output'], add_special_tokens=False)
        input_ids.extend(answer_ids)
        if with_loss:
            labels.extend(list(answer_ids))
        else:
            labels.extend([IGNORE_INDEX] * len(answer_ids))
        # EOS follows the answer (and shares its loss setting); when a turn
        # suppresses EOS, the next turn also skips BOS so the stream stays
        # contiguous.
        if turn.get('need_eos_token', True):
            needs_bos = True
            input_ids.extend(eos_token_id)
            if with_loss:
                labels.extend(list(eos_token_id))
            else:
                labels.extend([IGNORE_INDEX] * len(eos_token_id))
        else:
            needs_bos = False
        # Optional separator between turns, never with loss.
        sep = turn.get('sep', '')
        if sep != '':
            sep_ids = tokenizer.encode(sep, add_special_tokens=False)
            input_ids.extend(sep_ids)
            labels.extend([IGNORE_INDEX] * len(sep_ids))
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {'input_ids': input_ids, 'labels': labels}
def video_lisa_encode_multi_conv_fn(
    example,
    tokenizer,
    max_length,
    input_ids_with_output=True
):
    """Tokenize several prompt-only conversations from a single example.

    ``example['conversation']`` is a list of conversations; each conversation
    is itself a list of turns shaped like ``{'input': <prompt text>, ...}``.
    Only the ``'input'`` field of each turn is encoded — outputs are never
    appended, which is why the caller must pass ``input_ids_with_output=False``.

    Args:
        example: Sample containing a ``'conversation'`` list as described.
        tokenizer: HuggingFace-style tokenizer exposing ``encode``.
        max_length: Hard cap on the token count of each encoded conversation.
        input_ids_with_output: Kept for signature parity with
            ``video_lisa_encode_fn``; must be ``False`` here.

    Returns:
        dict with ``'input_ids'``: one token-id list per conversation.
    """
    bos_token_id, _eos_token_id = get_bos_eos_token_ids(tokenizer)
    assert not input_ids_with_output, \
        'video_lisa_encode_multi_conv_fn never encodes outputs'
    input_id_list = []
    for conv in example['conversation']:
        input_ids = []
        for single_turn_conversation in conv:
            # NOTE(review): the previous code guarded this with a
            # `next_needs_bos_token` flag that was never cleared, so BOS was
            # effectively prepended before every turn's prompt; this makes
            # that behavior explicit.
            input_ids += bos_token_id
            input_ids += tokenizer.encode(
                single_turn_conversation['input'], add_special_tokens=False)
        if len(input_ids) > max_length:
            input_ids = input_ids[:max_length]
        input_id_list.append(input_ids)
    return {'input_ids': input_id_list}