# open-human-feedback-chat / data / generate_dpo.py
# burtenshaw
# respond to feedback on prompts
# 01ba912
import json
from typing import TYPE_CHECKING, List, Literal, Union
from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, GeneratorStep, KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration
from typing_extensions import override
# Prompt used to produce a *better* replacement for a poorly rated final message.
# Rendered with Jinja; ``conversation`` is a list of {"role", "content"} dicts
# whose last entry is the message being replaced.
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to regenerate it as a high quality response.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()
CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when other assistants created a poor quality response."
# Prompt used to produce a *worse* replacement for a highly rated final message,
# giving the "rejected" side of a preference pair.
REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to regenerate it as a poor quality response.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()
REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants created a high quality response."
class FilterConversationRatings(Step):
    """Filters conversations based on the rating of a message.

    The target rating is derived from ``target_column``: ``"chosen"`` -> ``1``,
    ``"rejected"`` -> ``-1``. For every row, the ``conversation`` is truncated
    right after the first message carrying that rating; rows with no such
    message are dropped.
    """

    # Which side of the preference pair the kept conversations represent.
    target_column: Union[Literal["chosen"], Literal["rejected"]]
    # Maximum number of input rows examined per yielded output list.
    batch_size: int = 5

    @override
    def process(self, dataset: StepInput) -> "StepOutput":
        """Yield filtered ``{"conversation": [...]}`` rows in chunks.

        Args:
            dataset: List of row dicts, each expected to hold a
                ``"conversation"`` list of message dicts.

        Yields:
            One list per chunk of ``batch_size`` input rows. Each list holds
            only the rows that matched (it may be empty).
        """
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]
        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            # Single pass over the rows. A nested loop over ``batch`` would
            # append every match ``len(batch)`` times.
            for row in batch:
                messages = row.get("conversation") or []
                for idx, message in enumerate(messages, 1):
                    # Unrated messages (missing key or non-int such as None)
                    # can never match the target rating.
                    rating = message.get("rating")
                    if isinstance(rating, int) and rating == target_rating:
                        # Keep everything up to and including the rated message.
                        filtered_batch.append({"conversation": messages[:idx]})
                        break
            yield filtered_batch

    @property
    def outputs(self) -> "StepColumns":
        """Columns produced by this step."""
        return ["conversation"]
class AppendToConversationStep(Step):
    """Appends a generated message to a conversation.

    For each row, ``generated_conversation`` is the input ``conversation`` with
    its final message replaced by an assistant message holding ``generation``.
    ``conversation`` is rewritten to keep only ``role``/``content`` keys, so both
    columns share the same message shape.
    """

    @property
    def inputs(self) -> "StepColumns":
        """Columns this step reads."""
        return ["generation", "conversation"]

    @property
    def outputs(self) -> "StepColumns":
        """Columns this step writes."""
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        """Build the paired conversations, dropping rows that cannot be paired.

        Rows with an empty/None ``generation`` or ``conversation`` are dropped.
        Yielding them without a ``generated_conversation`` column would give
        the batch an inconsistent schema and break downstream column selection.

        Yields:
            A single list containing only the completed rows.
        """
        completed = []
        for row in inputs:
            generation = row["generation"]
            conversation = row["conversation"]
            if not generation or not conversation:
                continue
            # Everything except the final message, plus the new assistant reply.
            row["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in conversation[:-1]
            ] + [{"role": "assistant", "content": generation}]
            # Strip extra per-message fields (e.g. ratings) from the original.
            row["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in conversation
            ]
            completed.append(row)
        yield completed
# Rejected -> chosen: keep conversations up to their first message rated -1,
# have the LLM write an improved final reply, and emit the pair with the
# generated conversation as "chosen" and the original one as "rejected".
with Pipeline(
    description="Generate a chosen response to a rejected conversation.",
    name="conversation_rejection",
) as rejection_pipeline:
    rejected_dataset = FilterConversationRatings(target_column="rejected")
    chosen_text_gen = TextGeneration(
        columns=["conversation"],
        template=CHOSEN_TEMPLATE,
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct"),
    )
    append_chosen = AppendToConversationStep(
        output_mappings={
            "conversation": "rejected",
            "generated_conversation": "chosen",
        },
    )
    keep_columns = KeepColumns(columns=["chosen", "rejected"])

    # Wire the steps one edge at a time: filter -> generate -> pair -> select.
    rejected_dataset >> chosen_text_gen
    chosen_text_gen >> append_chosen
    append_chosen >> keep_columns
# Chosen -> rejected: keep conversations up to their first message rated 1,
# have the LLM write a deliberately worse final reply, and emit the pair with
# the generated conversation as "rejected" and the original one as "chosen".
with Pipeline(
    description="Generate a rejected response to a chosen conversation.",
    name="conversation_chosen",
) as chosen_pipeline:
    chosen_dataset = FilterConversationRatings(target_column="chosen")
    rejected_text_gen = TextGeneration(
        columns=["conversation"],
        template=REJECT_TEMPLATE,
        system_prompt=REJECT_SYSTEM_PROMPT,
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-70B-Instruct"),
    )
    append_rejected = AppendToConversationStep(
        output_mappings={
            "conversation": "chosen",
            "generated_conversation": "rejected",
        },
    )
    keep_columns = KeepColumns(columns=["chosen", "rejected"])

    # Wire the steps one edge at a time: filter -> generate -> pair -> select.
    chosen_dataset >> rejected_text_gen
    rejected_text_gen >> append_rejected
    append_rejected >> keep_columns
if __name__ == "__main__":
    dataset_path = "example_data.json"
    # Use a context manager so the file handle is closed deterministically,
    # and decode as UTF-8 explicitly instead of the platform default.
    with open(dataset_path, encoding="utf-8") as fp:
        data = json.load(fp)
    dataset = Dataset.from_list(data)
    # Distinct names so the pipeline results do not rebind the module-level
    # step objects ``rejected_dataset`` / ``chosen_dataset``.
    rejected_distiset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_distiset = chosen_pipeline.run(dataset=dataset, use_cache=False)
    # Results are indexed by config name ("default") and split ("train").
    dataset = concatenate_datasets(
        dsets=[
            rejected_distiset["default"]["train"],
            chosen_distiset["default"]["train"],
        ]
    )
    # NOTE(review): the combined preference dataset is built but never saved
    # or pushed anywhere — add a persistence step if it is meant to be kept.