|
import json
from typing import TYPE_CHECKING, Literal

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration
from typing_extensions import override

if TYPE_CHECKING:
    from distilabel.steps.typing import StepColumns, StepOutput
|
|
CHOSEN_TEMPLATE = """ |
|
You are provide with a conversation between a human and an AI assistant. |
|
The final message is of poor quality positively. Your task is to regenerate one of high quality. |
|
{% for message in conversation %} |
|
{{ message["role"] }}: {{ message["content"] }} |
|
{% endfor %} |
|
High quality response: |
|
""".rstrip() |
|
|
|
CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate high quality response when other assistants created a poor quality response." |
|
|
|
REJECT_TEMPLATE = """ |
|
You are provide with a conversation between a human and an AI assistant. |
|
The final message is of high quality positively. Your task is to regenerate one of poor quality. |
|
{% for message in conversation %} |
|
{{ message["role"] }}: {{ message["content"] }} |
|
{% endfor %} |
|
Poor quality response: |
|
""".rstrip() |
|
|
|
REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants created a high quality response." |
|
|
|
|
|
class FilterConversationRatings(Step):
    """Filters conversations by the ratings attached to their messages.

    Truncates each conversation at the first message whose rating matches the
    target column (1 for "chosen", -1 for "rejected") and drops conversations
    that contain no such message.
    """

    target_column: Literal["chosen", "rejected"]
    batch_size: int = 5

    @override
    def process(self, dataset: StepInput) -> "StepOutput":
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]

        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            for row in batch:
                _conversation = row["conversation"]
                conversation = None
                # Keep messages up to and including the first one whose rating
                # matches the target; ignore messages without an integer rating.
                for idx, message in enumerate(_conversation, 1):
                    if not isinstance(message["rating"], int):
                        continue
                    if message["rating"] == target_rating:
                        conversation = _conversation[:idx]
                        break
                if conversation:
                    filtered_batch.append({"conversation": conversation})
            yield filtered_batch

    @property
    def outputs(self) -> "StepColumns":
        return ["conversation"]
|
|
class AppendToConversationStep(Step):
    """Appends a generated message to a conversation."""

    @property
    def inputs(self) -> "StepColumns":
        return ["generation", "conversation"]

    @property
    def outputs(self) -> "StepColumns":
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        for item in inputs:
            # Skip rows with an empty generation or an empty conversation.
            if not item["generation"]:
                continue
            if not item["conversation"]:
                continue
            # Replace the final message with the newly generated one, dropping
            # the "rating" key from every message along the way.
            item["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in item["conversation"][:-1]
            ] + [{"role": "assistant", "content": item["generation"]}]
            item["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in item["conversation"]
            ]
        yield inputs
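
# A sketch of what AppendToConversationStep does to a single row (the values
# below are made up for illustration):
#
#   in:  {"generation": "Use sorted(xs).",
#         "conversation": [
#             {"role": "user", "content": "How do I sort?", "rating": None},
#             {"role": "assistant", "content": "idk", "rating": -1}]}
#   out: "generated_conversation" = original messages minus the last, plus the
#        new assistant message; "conversation" = original messages sans ratings.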
|
|
|
|
|
with Pipeline(
    name="conversation_rejection",
    description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:
    rejected_dataset = FilterConversationRatings(target_column="rejected")

    chosen_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        template=CHOSEN_TEMPLATE,
        columns=["conversation"],
    )

    append_chosen = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "chosen",
            "conversation": "rejected",
        },
    )

    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )

    rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns
|
|
with Pipeline(
    name="conversation_chosen",
    description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:
    chosen_dataset = FilterConversationRatings(target_column="chosen")

    rejected_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=REJECT_SYSTEM_PROMPT,
        template=REJECT_TEMPLATE,
        columns=["conversation"],
    )

    append_rejected = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "rejected",
            "conversation": "chosen",
        },
    )

    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )

    chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns
|
|
if __name__ == "__main__": |
|
|
|
dataset_path = "example_data.json" |
|
data = json.load(open(dataset_path)) |
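
    # The loader above expects records shaped roughly like this (illustrative,
    # inferred from how FilterConversationRatings reads each row):
    #
    #   [
    #       {
    #           "conversation": [
    #               {"role": "user", "content": "...", "rating": None},
    #               {"role": "assistant", "content": "...", "rating": -1}
    #           ]
    #       }
    #   ]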
|
|
|
    dataset = Dataset.from_list(data)
    rejected_distiset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_distiset = chosen_pipeline.run(dataset=dataset, use_cache=False)

    # Each pipeline returns a Distiset; concatenate the train splits into a
    # single preference dataset with "chosen" and "rejected" columns.
    dataset = concatenate_datasets(
        dsets=[rejected_distiset["default"]["train"], chosen_distiset["default"]["train"]]
    )
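
    # The combined dataset now holds one chosen/rejected pair per row; as a
    # follow-up it could, for example, be written to disk (hypothetical path):
    # dataset.save_to_disk("preference_dataset")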
|
|