from typing import Dict, Sequence

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def video_lisa_collate_fn(instances: Sequence[Dict],
                          pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                          return_hf_format: bool = False,
                          use_varlen_attn: bool = False):
    """Collate tokenized multimodal samples into one training batch.

    ``input_ids``/``labels`` are right-padded to the longest sample in the
    batch; optional vision, grounding, mask, box and point fields are
    gathered into lists (and stacked into tensors when shapes agree) and
    attached to the returned dict.
    """
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    # A field is collated only if at least one sample in the batch carries it.
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_pe = any(inst.get('image_grid_thw', None) is not None for inst in instances)
    has_fast_image = any(inst.get('fast_pixel_values', None) is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)
    has_fast_exists = any(inst.get('fast_exists') is not None for inst in instances)
    has_vp = any(inst.get('vp_overall_mask') is not None for inst in instances)
    has_prompt_mask = any(inst.get('prompt_masks') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            'If utilizing varlen attention, the batch size should be '
            f'set to 1, but got {len(instances)}')
        assert not has_image, (
            'Currently, it is not configured to accommodate the use of '
            'varlen Attention in multimodal training')

    # Per-field accumulation buffers.
    if has_image:
        pixel_values = []
        image_grid_thw = []
    if has_grounding_image:
        grounding_pixel_values = []
    # Populated in the grounding-image branch of the loop below; initialised
    # unconditionally so that branch never touches an undefined name.
    frames_per_batch = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []
    if has_fast_image:
        fast_pixel_values = []
    if has_fast_exists:
        fast_exists = []
    vp_overall_mask = [] if has_vp else None
    prompt_masks = [] if has_prompt_mask else None

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
            if has_pe:
                image_grid_thw.append(example['image_grid_thw'])
            if has_vp:
                if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
                    vp_overall_mask.append(example['vp_overall_mask'])
                else:
                    # Samples without a visual prompt contribute an all-False
                    # placeholder, one entry per element of this sample's
                    # pixel_values.
                    vp_overall_mask.append(
                        torch.Tensor([False] * len(pixel_values[-1])))

        if has_fast_image:
            if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
                fast_pixel_values.append(example['fast_pixel_values'])
        if has_fast_exists:
            if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
                fast_exists.append(example['fast_exists'])

        # Grounding images may come as a list of frames (video) or a single
        # frame; track how many frames each sample contributes.
        if has_grounding_image and 'g_pixel_values' in example.keys():
            if isinstance(example['g_pixel_values'], list):
                grounding_pixel_values += example['g_pixel_values']
                frames_per_batch.append(len(example['g_pixel_values']))
            else:
                grounding_pixel_values.append(example['g_pixel_values'])
                frames_per_batch.append(1)

        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                if isinstance(example['masks'], list):
                    if isinstance(example['masks'][0], np.ndarray):
                        _masks = np.stack(example['masks'], axis=0)
                        object_masks.append(torch.from_numpy(_masks))
                    else:
                        object_masks.append(
                            torch.stack(example['masks'], dim=0))
                else:
                    object_masks.append(example['masks'])

        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])
        if has_prompt_mask:
            if 'prompt_masks' in example.keys():
                prompt_masks.append(example['prompt_masks'])

    # Pad text fields to the longest sample in the batch.
    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        # Every sequence-level tensor must be padded to a length divisible by
        # the sequence parallel world size before it is split across ranks.
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] - cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        # Stack only when every sample has the same shape; otherwise keep a
        # list and let the model handle variable-sized inputs.
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['frames_per_batch'] = frames_per_batch
        data_dict['pixel_values'] = pixel_values
        if has_pe:
            data_dict['image_grid_thw'] = image_grid_thw

    if has_fast_image:
        if all(x.shape == fast_pixel_values[0].shape for x in fast_pixel_values):
            fast_pixel_values = torch.stack(fast_pixel_values, dim=0)
        data_dict['fast_pixel_values'] = fast_pixel_values
    if has_fast_exists:
        data_dict['fast_exists'] = fast_exists
    if has_vp:
        data_dict['vp_overall_mask'] = torch.cat(vp_overall_mask, dim=0)
    if has_prompt_mask:
        data_dict['prompt_masks'] = prompt_masks
    if has_grounding_image:
        # Grounding frames may differ in size, so they stay a flat list.
        # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
        #     grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
        data_dict['g_pixel_values'] = grounding_pixel_values
    if has_mask:
        data_dict['masks'] = object_masks
    if has_bboxes:
        data_dict['bboxes'] = object_bboxes
    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    return {'data': data_dict, 'data_samples': None}
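

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, text-only smoke test showing
# the padding behaviour. In training, this function is typically passed to a
# torch DataLoader as ``collate_fn`` (often bound via functools.partial). The
# sample dicts below are fabricated placeholders; real samples come from the
# project's dataset classes and may carry the optional vision fields as well.
# This assumes get_sequence_parallel_world_size() falls back to 1 outside a
# distributed run.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from functools import partial

    # Two fake tokenized samples with only the required text fields.
    fake_instances = [
        {'input_ids': [1, 5, 7, 2], 'labels': [IGNORE_INDEX, 5, 7, 2]},
        {'input_ids': [1, 9, 2], 'labels': [IGNORE_INDEX, 9, 2]},
    ]

    collate = partial(video_lisa_collate_fn, return_hf_format=True)
    batch = collate(fake_instances)

    # input_ids/labels are right-padded to the longest sample (length 4);
    # attention_mask marks the real tokens of each sequence.
    print(batch['input_ids'].shape)        # torch.Size([2, 4])
    print(batch['attention_mask'].sum(1))  # tensor([4, 3])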