leaderboard / sharegpt /compare_distributions.py
ohjuny's picture
Filter Conversations Not Started By "human" (#35)
2afd5bf unverified
raw
history blame
2.03 kB
import json
import matplotlib.pyplot as plt
import numpy as np
from transformers import (
AutoTokenizer,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
# Open datasets
file_paths = ["ShareGPT_V3_filtered.json", "ShareGPT_V3_filtered_500.json"]
names = [file_path[:-5] for file_path in file_paths]
data_lists = []
for file_path in file_paths:
with open(file_path, "r", encoding="utf-8") as file:
data_list = json.load(file)
data_lists.append(data_list)
for name, data_list in zip(names, data_lists):
print(f"{name}: {len(data_list)}")
# Get prompt lengths using tokenizer
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
all_prompts = [
[data["conversations"][0]["value"] for data in data_lists if data["conversations"][0]["from"] == "human"]
for data_lists in data_lists
]
all_token_ids_per_prompts = [tokenizer(prompts).input_ids for prompts in all_prompts]
all_prompt_lens = [
[len(token_ids) for token_ids in token_ids_per_prompt]
for token_ids_per_prompt in all_token_ids_per_prompts
]
# Plotting the histograms
for name, prompt_lens in zip(names, all_prompt_lens):
plt.hist(
prompt_lens,
bins=range(min(prompt_lens), max(prompt_lens) + 1),
edgecolor="black",
)
plt.xlabel("Prompt Length (number of tokens)")
plt.ylabel("Frequency")
plt.title(f"Histogram of {name}")
plt.savefig(f"{name}_distribution.png")
plt.close()
# Plotting the CDF
for name, prompt_lens in zip(names, all_prompt_lens):
values, counts = np.unique(prompt_lens, return_counts=True)
relative_frequencies = counts / len(prompt_lens)
sorted_data = np.sort(values)
cumulative_frequencies = np.cumsum(relative_frequencies)
plt.step(sorted_data, cumulative_frequencies, where="post", label=name)
plt.title(f"Cumulative Distribution Function (CDF) Overlayed")
plt.xlabel("Prompt Length (number of tokens)")
plt.ylabel("Cumulative Probability")
plt.savefig(f"{name}_cdf.png")
plt.close()