Spaces:
Sleeping
Sleeping
import dataclasses | |
import functools | |
import logging | |
import os | |
import pickle | |
import pprint | |
import random | |
from typing import List | |
# Directory whose files are listed directly by ``_get_data`` when the pickled
# split lists are not used (``hb_split=False``).
EMPTY_DATA_PATH: str = "tangram_pngs/"
# Directory holding the pickled ``test_imgs.pkl`` / ``train_imgs.pkl`` name lists
# read by ``_get_data`` when ``hb_split=True``.
SPLIT_PATH: str = "dataset_splits/"
@dataclasses.dataclass
class GameConfig:
    """Image assignment for one game: the two players' contexts and the targets.

    Must be a dataclass: it is constructed with keyword arguments and
    serialized with ``dataclasses.asdict`` by ``generate_game_config``.
    """

    # Image file names in the speaker's display order.
    speaker_context: List[str]
    # The same images as ``speaker_context``; generate_game_config fills this
    # with an independently shuffled ordering.
    listener_context: List[str]
    # Images drawn from the context and designated as targets.
    targets: List[str]
def generate_game_config(
    context_size: int = 10,
    min_targets: int = 3,
    max_targets: int = 5,
) -> GameConfig:
    """Sample a random game configuration from the image corpus.

    Draws ``context_size`` distinct image names from ``_get_data()``, picks
    between ``min_targets`` and ``max_targets`` (inclusive) of them as
    targets, and gives the listener the same images in an independently
    shuffled order. With the default arguments the sequence of ``random``
    calls is identical to the original hard-coded version.

    Args:
        context_size: Number of images in each player's context.
        min_targets: Smallest number of targets to draw (inclusive).
        max_targets: Largest number of targets to draw (inclusive).

    Returns:
        A GameConfig whose ``listener_context`` is a permutation of
        ``speaker_context`` and whose ``targets`` are a subset of it.

    Raises:
        ValueError: propagated from ``random`` if ``context_size`` exceeds
            the corpus size, if ``min_targets > max_targets``, or if the
            drawn target count exceeds ``context_size``.
    """
    corpus = _get_data()
    context = random.sample(corpus, context_size)
    num_targets = random.randint(min_targets, max_targets)
    targets = random.sample(context, num_targets)
    # Shuffle an index permutation instead of ``context`` itself so the
    # speaker's ordering is left untouched.
    listener_order = list(range(len(context)))
    random.shuffle(listener_order)
    config = GameConfig(
        speaker_context=context,
        listener_context=[context[i] for i in listener_order],
        targets=targets,
    )
    logging.info("context_dict: %s", pprint.pformat(dataclasses.asdict(config)))
    return config
# File names dropped from every corpus listing: a macOS Finder artifact plus
# three images excluded as duplicates.
_EXCLUDED_IMAGE_FILES = frozenset(
    {".DS_Store", "page6-51.png", "page6-66.png", "page4-170.png"}
)


@functools.lru_cache(maxsize=None)
def _load_image_names(hb_split: bool, source_dir: str) -> tuple:
    """Read and filter image file names once per (mode, directory) pair.

    Returns an immutable tuple so the cached value cannot be altered by
    callers; later changes to the files on disk are not picked up.
    """
    if hb_split:
        # 912 images (per the original note)
        names: List[str] = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            with open(os.path.join(source_dir, split_file), "rb") as f:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # point this at trusted, locally shipped split files.
                names.extend(pickle.load(f))
        names = [name + ".png" for name in names]
    else:
        # 1013 images (per the original note)
        names = os.listdir(source_dir)
    return tuple(name for name in names if name not in _EXCLUDED_IMAGE_FILES)


def _get_data(hb_split: bool = True, data_dir=None, split_dir=None) -> List[str]:
    """Return the tangram image file names that make up the corpus.

    Args:
        hb_split: If True (default), read image names from the pickled
            ``test_imgs.pkl`` then ``train_imgs.pkl`` lists in ``split_dir``
            and append ``.png`` to each. If False, list every file in
            ``data_dir``.
        data_dir: Directory listed when ``hb_split`` is False; defaults to
            ``EMPTY_DATA_PATH``.
        split_dir: Directory holding the pickled lists when ``hb_split`` is
            True; defaults to ``SPLIT_PATH``.

    Returns:
        A new list of file names with ``.DS_Store`` and the known duplicate
        images removed. Disk is read only on the first call per
        (mode, directory); each call returns a fresh list that callers may
        mutate freely.
    """
    if hb_split:
        source_dir = SPLIT_PATH if split_dir is None else split_dir
    else:
        source_dir = EMPTY_DATA_PATH if data_dir is None else data_dir
    return list(_load_image_names(bool(hb_split), source_dir))