Spaces:
Sleeping
Sleeping
File size: 1,468 Bytes
2f56479 87b7a45 2f56479 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import dataclasses
import functools
import logging
import os
import pickle
import pprint
import random
from typing import List, Optional
# Directory whose file listing _get_data() uses when hb_split is False.
# Relative path, so it is resolved against the current working directory.
EMPTY_DATA_PATH: str = "tangram_pngs/"
# Directory containing the pickled split files "test_imgs.pkl" and
# "train_imgs.pkl" that _get_data() reads when hb_split is True.
# Also relative to the current working directory.
SPLIT_PATH: str = "dataset_splits/"
@dataclasses.dataclass(frozen=True)
class GameConfig:
    """Immutable description of a single reference-game round.

    Every field is a list of image filenames. In ``generate_game_config``,
    ``listener_context`` is a permutation of ``speaker_context``, and
    ``targets`` is drawn from ``speaker_context``. This class does not
    enforce either relationship.

    ``frozen=True`` only blocks re-binding the attributes. The lists
    themselves can still be mutated in place. Because the fields are lists,
    calling ``hash()`` on an instance raises ``TypeError``.
    """

    # Images in the order they were sampled for the speaker.
    speaker_context: List[str]

    # The images the listener sees, in the listener's own order.
    listener_context: List[str]

    # The subset of images selected as targets for this round.
    targets: List[str]
def generate_game_config(
    context_size: int = 10,
    min_targets: int = 3,
    max_targets: int = 5,
) -> GameConfig:
    """Sample a random reference-game configuration.

    The function draws ``context_size`` distinct images from the corpus
    returned by ``_get_data()``. It picks between ``min_targets`` and
    ``max_targets`` of them (inclusive) as targets. The listener receives
    the same images in an independently shuffled order. The defaults give
    a 10-image context with 3-5 targets.

    Args:
        context_size: Number of distinct images shown to both players.
        min_targets: Smallest allowed number of targets.
        max_targets: Largest allowed number of targets; at most
            ``context_size``.

    Returns:
        The sampled ``GameConfig``. The config is also logged at INFO level.

    Raises:
        ValueError: If ``0 <= min_targets <= max_targets <= context_size``
            does not hold, or if the corpus has fewer than ``context_size``
            images. The second case is raised by ``random.sample``.
    """
    # Validate before touching the RNG. Otherwise an impossible range would
    # only surface, intermittently, as an opaque error inside random.
    if not 0 <= min_targets <= max_targets <= context_size:
        raise ValueError(
            "expected 0 <= min_targets <= max_targets <= context_size, got "
            f"min_targets={min_targets}, max_targets={max_targets}, "
            f"context_size={context_size}"
        )
    corpus = _get_data()
    context = random.sample(corpus, context_size)
    num_targets = random.randint(min_targets, max_targets)
    targets = random.sample(context, num_targets)
    # Shuffle a copy so speaker_context keeps its sampled order. The copy has
    # the same length as context, which keeps the listener a permutation of
    # the speaker's images whatever context_size is.
    listener_context = list(context)
    random.shuffle(listener_context)
    config = GameConfig(
        speaker_context=context,
        listener_context=listener_context,
        targets=targets,
    )
    logging.info("context_dict: %s", pprint.pformat(dataclasses.asdict(config)))
    return config
@functools.cache
def _get_data(
    hb_split: bool = True,
    *,
    data_path: Optional[str] = None,
    split_path: Optional[str] = None,
) -> List[str]:
    """Return the filenames of the tangram images available for sampling.

    Results are memoized per argument combination by ``functools.cache``,
    so every call with the same arguments returns the SAME list object.
    Callers must treat it as read-only; mutating it would corrupt all
    later calls.

    Args:
        hb_split: If True (the default), read image ids from the pickled
            ``test_imgs.pkl`` and then ``train_imgs.pkl`` split files and
            append ``.png`` to each id. If False, use the sorted file listing
            of the image directory.
        data_path: Image directory used when ``hb_split`` is False. Defaults
            to ``EMPTY_DATA_PATH``.
        split_path: Directory holding the split pickles, used when
            ``hb_split`` is True. Defaults to ``SPLIT_PATH``.

    Returns:
        Image filenames, excluding ``.DS_Store`` and three images that are
        treated as duplicates.
    """
    if hb_split:
        # 912 images
        split_dir = SPLIT_PATH if split_path is None else split_path
        paths = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            # NOTE(review): pickle.load can execute arbitrary code. These split
            # files must come from a trusted source.
            with open(os.path.join(split_dir, split_file), "rb") as f:
                # extend() accepts any iterable of ids (list or tuple alike).
                paths.extend(pickle.load(f))
        paths = [path + ".png" for path in paths]
    else:
        # 1013 images
        data_dir = EMPTY_DATA_PATH if data_path is None else data_path
        # os.listdir returns entries in arbitrary, filesystem-dependent order.
        # Sorting it makes seeded random sampling reproducible across machines.
        paths = sorted(os.listdir(data_dir))
    # ".DS_Store" is macOS folder metadata. The other three images are
    # treated as duplicates and dropped from the corpus.
    excluded = {".DS_Store", "page6-51.png", "page6-66.png", "page4-170.png"}
    return [path for path in paths if path not in excluded]
|