|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
from dataclasses import dataclass |
|
from typing import Dict, List, Optional |
|
|
|
from datasets import load_dataset |
|
from transformers import HfArgumentParser |
|
|
|
|
|
@dataclass
class ScriptArguments:
    r"""
    Command-line options for building the HH-RLHF helpful-base preference dataset.

    Args:
        push_to_hub (`bool`, *optional*, defaults to `False`):
            If enabled, upload the processed dataset to the Hugging Face Hub.
        repo_id (`str`, *optional*, defaults to `"trl-lib/hh-rlhf-helpful-base"`):
            Hub repository ID that the dataset is uploaded to.
        dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
            Worker count forwarded as `num_proc` when mapping over the dataset.
    """

    # Upload toggle; disabled unless requested on the command line.
    push_to_hub: bool = False
    # Destination repository on the Hub.
    repo_id: str = "trl-lib/hh-rlhf-helpful-base"
    # Forwarded as `num_proc` to the dataset `map` call.
    dataset_num_proc: Optional[int] = None
|
|
|
|
|
def common_start(str1: str, str2: str) -> str:
    """Return the longest prefix shared by ``str1`` and ``str2``.

    Characters are compared position by position.
    The scan stops at the first mismatch or at the end of the shorter string.
    The result is therefore ``""`` when either input is empty or the first
    characters already differ.
    """
    limit = min(len(str1), len(str2))
    pos = 0
    while pos < limit and str1[pos] == str2[pos]:
        pos += 1
    return str1[:pos]
|
|
|
|
|
def extract_dialogue(example: Dict[str, str]) -> Dict[str, List[Dict[str, str]]]:
    r"""Convert one raw HH-RLHF row into conversational prompt/chosen/rejected messages.

    The raw ``"chosen"`` and ``"rejected"`` fields are transcripts built from
    ``"\n\nHuman: "`` / ``"\n\nAssistant: "`` turns. The code assumes both
    transcripts share the same history and differ only in the final assistant
    reply. That shared history is recovered as their common prefix.

    Args:
        example: A row with string fields ``"chosen"`` and ``"rejected"``.

    Returns:
        A dict with three keys:

        - ``"prompt"``: the shared history, as a list of
          ``{"role": ..., "content": ...}`` messages. Roles are ``"user"`` or
          ``"assistant"``.
        - ``"chosen"``: a single-message list holding the chosen final reply,
          with role ``"assistant"``.
        - ``"rejected"``: the same for the rejected reply.
    """
    assistant_tag = "\n\nAssistant: "
    human_tag = "\n\nHuman: "

    # The prompt is the text common to both transcripts.
    prompt_text = common_start(example["chosen"], example["rejected"])

    # The two replies may begin with the same characters. In that case the common
    # prefix runs past the final assistant tag, so cut back to the last complete tag.
    # NOTE(review): if no complete tag is present, rfind returns -1, so the slice
    # keeps all but the last character before re-appending the tag. Behavior is
    # kept as-is; confirm this cannot occur for the data in use.
    if not prompt_text.endswith(assistant_tag):
        prompt_text = prompt_text[: prompt_text.rfind(assistant_tag)] + assistant_tag

    # Everything after the prompt (including its trailing tag) is the final reply.
    chosen_line = example["chosen"][len(prompt_text) :]
    rejected_line = example["rejected"][len(prompt_text) :]

    # Drop the trailing generation tag; it is not part of the history itself.
    prompt_text = prompt_text[: -len(assistant_tag)]

    # The capturing group keeps the tags, giving [lead, tag, text, tag, text, ...].
    prompt_lines = re.split(r"(\n\nAssistant: |\n\nHuman: )", prompt_text)
    # `lead` is whatever precedes the first tag. It is presumably empty, since the
    # transcripts start with "\n\nHuman: ". Discard it to leave tag/text pairs.
    tags_and_texts = prompt_lines[1:]

    prompt = [
        {"role": "user" if tag == human_tag else "assistant", "content": content}
        for tag, content in zip(tags_and_texts[::2], tags_and_texts[1::2])
    ]

    # The chosen role was previously misspelled "assitant". It now matches the role
    # used for "rejected" and for assistant turns in the prompt.
    chosen = [{"role": "assistant", "content": chosen_line}]
    rejected = [{"role": "assistant", "content": rejected_line}]

    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
|
|
|
|
|
def runner(arguments):
    """Parse CLI options, build the helpful-base preference dataset, and optionally push it.

    Args:
        arguments: The dataclass type(s) handed to ``HfArgumentParser``.
            The first parsed dataclass must provide ``dataset_num_proc``.
            ``push_to_hub`` and ``repo_id`` are honored when present, as they
            are on ``ScriptArguments``.

    Returns:
        The dataset produced by mapping ``extract_dialogue`` over
        ``Anthropic/hh-rlhf`` (``data_dir="helpful-base"``). Previously the
        function returned ``None`` and discarded the result.
    """
    parser = HfArgumentParser(arguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    dataset = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
    dataset = dataset.map(extract_dialogue, num_proc=script_args.dataset_num_proc)

    # `push_to_hub` / `repo_id` used to be parsed but never acted on.
    # getattr keeps the function usable with argument classes that lack these fields.
    if getattr(script_args, "push_to_hub", False):
        dataset.push_to_hub(script_args.repo_id)

    return dataset
|
|
|
|
|
|
|
|