# open-human-feedback-chat / ml/trl_rlhf_data.py
# burtenshaw
# migrate all ml files into subdir
# e9484c6
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from dataclasses import dataclass
from typing import Dict, List, Optional
from datasets import load_dataset
from transformers import HfArgumentParser
@dataclass
class ScriptArguments:
    r"""
    Command-line options understood by this script.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            Flag indicating whether the dataset should be pushed to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
            ID of the Hugging Face repository the dataset is pushed to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            How many workers to use while processing the dataset.
    """

    push_to_hub: bool = False
    repo_id: str = "trl-lib/hh-rlhf-helpful-base"
    dataset_num_proc: Optional[int] = None
def common_start(str1: str, str2: str) -> str:
    """Return the longest prefix that `str1` and `str2` have in common."""
    # Only positions that exist in both strings can belong to the shared prefix
    limit = min(len(str1), len(str2))
    matched = 0
    # Advance until the first position where the two strings disagree
    while matched < limit and str1[matched] == str2[matched]:
        matched += 1
    # Everything before the first disagreement is the common start
    return str1[:matched]
def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
    r"""
    Split one raw preference record into a shared prompt and two completions.

    Args:
        example (`Dict[str, str]`):
            Record with a `"chosen"` and a `"rejected"` transcript. Turns inside a transcript are introduced by
            the markers `"\n\nHuman: "` and `"\n\nAssistant: "`.

    Returns:
        `Dict[str, List[Dict[str, str]]]`: Dictionary with three keys. `"prompt"` holds the shared conversation
        as a list of `{"role", "content"}` messages (`"user"` for human turns, `"assistant"` otherwise);
        `"chosen"` and `"rejected"` each hold a single assistant message with the respective final answer.
    """
    generation_prompt = "\n\nAssistant: "

    # The prompt is whatever the chosen and rejected transcripts have in common at their start
    prompt_text = common_start(example["chosen"], example["rejected"])

    # The two final answers may themselves begin with the same characters, so the shared prefix can run past
    # the last assistant marker; cut it back so that it ends exactly on that marker.
    # NOTE: this assumes the shared prefix contains at least one assistant marker (`rfind` returns -1 otherwise).
    if not prompt_text.endswith(generation_prompt):
        prompt_text = prompt_text[: prompt_text.rfind(generation_prompt)] + generation_prompt

    # Whatever follows the shared prefix is the final answer of each transcript
    chosen_line = example["chosen"][len(prompt_text) :]
    rejected_line = example["rejected"][len(prompt_text) :]

    # Remove the generation prompt ("\n\nAssistant: ") from the prompt
    prompt_text = prompt_text[: -len(generation_prompt)]

    # Split at every speaker marker; the capturing group keeps the markers in the result, which therefore
    # alternates marker, text, marker, text, ... after the leading element.
    prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)
    # Drop the text preceding the first marker (empty when the transcript starts with a marker)
    prompt_lines = prompt_lines[1:]

    # Pair every marker with the text that follows it
    prompt = [
        {"role": "user" if marker == "\n\nHuman: " else "assistant", "content": content}
        for marker, content in zip(prompt_lines[::2], prompt_lines[1::2])
    ]

    # Both completions are assistant turns (the chosen role was previously misspelled as "assitant")
    chosen = [{"role": "assistant", "content": chosen_line}]
    rejected = [{"role": "assistant", "content": rejected_line}]

    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
def runner(arguments):
    """
    Parse the script arguments, load the raw dataset and convert every record with `extract_dialogue`.

    Args:
        arguments:
            Passed straight to `HfArgumentParser`; the first parsed dataclass is used as the script arguments
            (presumably `ScriptArguments` - confirm against the caller).

    Returns:
        `None`. The converted dataset is not returned, and pushing it to the Hub is currently disabled.
    """
    # Only the first parsed dataclass is used
    script_args = HfArgumentParser(arguments).parse_args_into_dataclasses()[0]

    # Load the "helpful-base" data directory and convert it in one go
    dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base").map(
        extract_dialogue,
        num_proc=script_args.dataset_num_proc,
    )

    # NOTE: the upload step is disabled; the function ends here and returns None.
    # if script_args.push_to_hub:
    #     dataset.push_to_hub(script_args.repo_id)