simon_says_v2 / dataset_maker.py
ericmichael's picture
first commit
0079b8d
import yaml
import json
from sklearn.model_selection import train_test_split
# Build the validation and test splits from the labelled YAML examples.
#
# Each entry loaded from data/data.yaml is indexed with a "category" key.
# Every difficulty tier is split 50/50 on its own, so both output files get
# roughly the same share of easy, medium and hard examples.

# Explicit UTF-8 so the YAML is decoded the same way on every platform
# instead of with the locale's default encoding.
with open("data/data.yaml", "r", encoding="utf-8") as f:
    data = yaml.safe_load(f)

validation = []
test = []

# Tiers are processed in a fixed order so the output ordering
# (easy, then medium, then hard) stays stable.
for category in ("easy", "medium", "hard"):
    items = [item for item in data if item["category"] == category]
    if not items:
        # train_test_split raises ValueError on an empty list; a tier with
        # no examples simply contributes nothing to either set.
        continue
    # The fixed random_state keeps the split reproducible between runs.
    # NOTE: a tier with exactly one example still raises ValueError here,
    # because a 50/50 split would leave one half empty.
    val_items, test_items = train_test_split(
        items, test_size=0.5, random_state=42
    )
    validation.extend(val_items)
    test.extend(test_items)

# Write the validation set to a JSON file
with open("validation.json", "w", encoding="utf-8") as f:
    json.dump(validation, f)

# Write the test set to a JSON file
with open("test.json", "w", encoding="utf-8") as f:
    json.dump(test, f)