# Copyright 2024 OpenAccess AI Collective and the LlamaFactory team.
#
# This code is inspired by the OpenAccess AI Collective's axolotl library.
# https://github.com/OpenAccess-AI-Collective/axolotl/blob/main/src/axolotl/monkeypatch/utils.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence

import torch
from transformers import DataCollatorForSeq2Seq


if TYPE_CHECKING:
    from transformers import ProcessorMixin

    from .template import Template
def prepare_4d_attention_mask(attention_mask_with_indices: "torch.Tensor", dtype: "torch.dtype") -> "torch.Tensor":
    r"""
    Expands the attention mask with indices from (batch_size, seq_len) to (batch_size, 1, seq_len, seq_len),
    handling packed sequences and transforming the mask into lower-triangular form to prevent attending to future tokens.

    e.g.
    ```python
    # input
    [[1, 1, 2, 2, 2, 0]]
    # output
    [
        [
            [
                [o, x, x, x, x, x],
                [o, o, x, x, x, x],
                [x, x, o, x, x, x],
                [x, x, o, o, x, x],
                [x, x, o, o, o, x],
                [x, x, x, x, x, x],
            ]
        ]
    ]
    ```
    where `o` equals `0.0` and `x` equals `min_dtype`.
    """
    bsz, seq_len = attention_mask_with_indices.size()
    min_dtype = torch.finfo(dtype).min
    expanded_mask = attention_mask_with_indices[:, None, None, :].expand(bsz, 1, seq_len, seq_len)
    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
    padding_mask = torch.where(expanded_mask != 0, 1, 0)
    # Create a block-diagonal mask: a position may only attend to tokens carrying the same sequence index.
    attention_mask_4d = torch.eq(expanded_mask, expanded_mask.transpose(-1, -2)).int() * padding_mask
    # Use the lower triangular mask to zero out the upper triangular part
    attention_mask_4d *= torch.tril(torch.ones((seq_len, seq_len), dtype=torch.long))
    # Invert the attention mask: visible positions become 0.0 and masked positions become min_dtype.
    attention_mask_4d = torch.where(attention_mask_4d != 0, torch.tensor(0, dtype=dtype), min_dtype)
    return attention_mask_4d
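
# A minimal usage sketch (not part of the original module): the packed index mask from the docstring
# above yields a (1, 1, 6, 6) block-diagonal, lower-triangular additive mask.
#
#     mask_with_indices = torch.tensor([[1, 1, 2, 2, 2, 0]])
#     mask_4d = prepare_4d_attention_mask(mask_with_indices, torch.float32)
#     assert mask_4d.shape == (1, 1, 6, 6)
#     assert mask_4d[0, 0, 2, 1].item() == torch.finfo(torch.float32).min  # different packed sequences
#     assert mask_4d[0, 0, 3, 2].item() == 0.0  # same packed sequence, causal direction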

@dataclass
class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
    r"""
    Data collator that supports VLMs.

    Features should contain input_ids, attention_mask, labels, and optionally images and videos.
    """

    template: Optional["Template"] = None
    processor: Optional["ProcessorMixin"] = None

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens = [], [], [], [], []
        for feature in features:
            images = feature.pop("images", None) or []
            videos = feature.pop("videos", None) or []
            batch_images.extend(images)
            batch_videos.extend(videos)
            batch_imglens.append(len(images))
            batch_vidlens.append(len(videos))
            batch_seqlens.append(len(feature["input_ids"]))

        mm_inputs = self.template.mm_plugin.get_mm_inputs(
            batch_images, batch_videos, batch_imglens, batch_vidlens, batch_seqlens, self.processor
        )
        if "token_type_ids" in mm_inputs:
            token_type_ids = mm_inputs.pop("token_type_ids")
            for i, feature in enumerate(features):
                feature["token_type_ids"] = token_type_ids[i]

        features: Dict[str, "torch.Tensor"] = super().__call__(features)
        features.update(mm_inputs)
        return features
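
# A hedged usage sketch (not from the original file): `tokenizer`, `template` and `processor` are
# assumed to come from LLaMA-Factory's loader and template helpers; exact names may differ by version.
#
#     collator = MultiModalDataCollatorForSeq2Seq(
#         tokenizer=tokenizer, template=template, processor=processor, label_pad_token_id=-100
#     )
#     batch = collator(
#         [{"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3], "images": [], "videos": []}]
#     )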

@dataclass
class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator that builds the 4D block-diagonal attention mask for packed sequences.
    """

    block_diag_attn: bool = False
    attn_implementation: Literal["eager", "sdpa", "flash_attention_2"] = "eager"
    compute_dtype: "torch.dtype" = torch.float32

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        features = super().__call__(features)
        if self.block_diag_attn and self.attn_implementation != "flash_attention_2":
            features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype)

        return features
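
# An illustrative sketch (not part of the original module) of when the 4D path triggers: with
# `block_diag_attn=True` and a non-flash-attention backend, the padded 2D mask from the parent
# collator is expanded; flash_attention_2 keeps the 2D mask.
#
#     collator = SFTDataCollatorWith4DAttentionMask(
#         tokenizer=tokenizer, template=template, block_diag_attn=True,
#         attn_implementation="sdpa", compute_dtype=torch.bfloat16,
#     )
#     batch = collator(features)  # batch["attention_mask"] has shape (bsz, 1, seq_len, seq_len)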

@dataclass
class PairwiseDataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for pairwise data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        r"""
        Pads batched data to the longest sequence in the batch.

        We generate 2 * n examples, where the first n examples are the chosen examples and
        the last n examples are the rejected examples.
        """
        concatenated_features = []
        for key in ("chosen", "rejected"):
            for feature in features:
                target_feature = {
                    "input_ids": feature["{}_input_ids".format(key)],
                    "attention_mask": feature["{}_attention_mask".format(key)],
                    "labels": feature["{}_labels".format(key)],
                    "images": feature["images"],
                    "videos": feature["videos"],
                }
                concatenated_features.append(target_feature)

        return super().__call__(concatenated_features)
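
# A non-authoritative example of the per-example feature layout this collator expects: each example
# carries chosen_* and rejected_* variants that are flattened into a 2 * n batch.
#
#     features = [
#         {
#             "chosen_input_ids": [...], "chosen_attention_mask": [...], "chosen_labels": [...],
#             "rejected_input_ids": [...], "rejected_attention_mask": [...], "rejected_labels": [...],
#             "images": [], "videos": [],
#         },
#     ]
#     batch = PairwiseDataCollatorWithPadding(tokenizer=tokenizer, template=template)(features)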

@dataclass
class KTODataCollatorWithPadding(MultiModalDataCollatorForSeq2Seq):
    r"""
    Data collator for KTO data.
    """

    def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, "torch.Tensor"]:
        target_features = []
        kl_features = []
        kto_tags = []
        for feature in features:
            target_feature = {
                "input_ids": feature["input_ids"],
                "attention_mask": feature["attention_mask"],
                "labels": feature["labels"],
                "images": feature["images"],
                "videos": feature["videos"],
            }
            kl_feature = {
                "input_ids": feature["kl_input_ids"],
                "attention_mask": feature["kl_attention_mask"],
                "labels": feature["kl_labels"],
                "images": feature["images"],
                "videos": feature["videos"],
            }
            target_features.append(target_feature)
            kl_features.append(kl_feature)
            kto_tags.append(feature["kto_tags"])

        batch = super().__call__(target_features)
        kl_batch = super().__call__(kl_features)
        batch["kl_input_ids"] = kl_batch["input_ids"]
        batch["kl_attention_mask"] = kl_batch["attention_mask"]
        batch["kl_labels"] = kl_batch["labels"]
        if "token_type_ids" in kl_batch:
            batch["kl_token_type_ids"] = kl_batch["token_type_ids"]

        batch["kto_tags"] = torch.tensor(kto_tags)
        return batch
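
# A hedged sketch of the batch this collator produces (keys taken from the code above, the rest is
# illustrative): the target and KL sequences are padded separately and merged into a single dict.
#
#     collator = KTODataCollatorWithPadding(tokenizer=tokenizer, template=template)
#     batch = collator(features)
#     # batch contains: input_ids, attention_mask, labels,
#     #                 kl_input_ids, kl_attention_mask, kl_labels, kto_tags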