File size: 1,468 Bytes
2f56479
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87b7a45
 
2f56479
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import dataclasses
import functools
import logging
import os
import pickle
import pprint
import random
from typing import List

# Directory whose entries are listed as the image corpus when
# `_get_data(hb_split=False)` is used.
EMPTY_DATA_PATH = "tangram_pngs/"
# Directory holding the pickled "test_imgs.pkl" / "train_imgs.pkl" image-name
# lists read by `_get_data(hb_split=True)`.
SPLIT_PATH = "dataset_splits/"


@dataclasses.dataclass(frozen=True)
class GameConfig:
    """Configuration for a single game round.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. The lists they hold are still ordinary mutable lists.
    """

    # Image file names in the order presented to the speaker.
    speaker_context: List[str]
    # Image file names in the order presented to the listener; as built by
    # `generate_game_config`, a permutation of `speaker_context`.
    listener_context: List[str]
    # Images selected as targets; as built by `generate_game_config`, a
    # subset of `speaker_context`.
    targets: List[str]


def generate_game_config(
    *,
    context_size: int = 10,
    min_targets: int = 3,
    max_targets: int = 5,
) -> GameConfig:
    """Sample a random game round from the image corpus.

    Draws ``context_size`` distinct images from ``_get_data()`` as the shared
    context. It picks between ``min_targets`` and ``max_targets`` (inclusive)
    of them as targets. The listener gets the same images in an independently
    shuffled order. With the default arguments, the sequence of ``random``
    calls matches the original hard-coded implementation.

    Args:
        context_size: Number of images in the shared context.
        min_targets: Smallest number of targets to draw (inclusive).
        max_targets: Largest number of targets to draw (inclusive).

    Returns:
        A ``GameConfig`` whose ``listener_context`` is a permutation of
        ``speaker_context`` and whose ``targets`` are a subset of it.

    Raises:
        ValueError: If the target bounds are inconsistent with the context
            size, or if the corpus holds fewer than ``context_size`` images
            (raised by ``random.sample``).
    """
    # Validate up front: otherwise an oversized max_targets would only fail
    # on the runs where randint happens to draw a too-large value.
    if not 0 <= min_targets <= max_targets <= context_size:
        raise ValueError(
            f"need 0 <= min_targets <= max_targets <= context_size, got "
            f"min_targets={min_targets}, max_targets={max_targets}, "
            f"context_size={context_size}"
        )

    corpus = _get_data()
    context = random.sample(corpus, context_size)
    num_targets = random.randint(min_targets, max_targets)
    targets = random.sample(context, num_targets)

    # Derive the permutation from the actual context length so it can never
    # disagree with the number of sampled images.
    listener_order = list(range(len(context)))
    random.shuffle(listener_order)

    config = GameConfig(
        speaker_context=context,
        listener_context=[context[i] for i in listener_order],
        targets=targets,
    )
    logging.info("context_dict: %s", pprint.pformat(dataclasses.asdict(config)))
    return config

@functools.cache
def _get_data(hb_split: bool = True) -> List[str]:
    """Return the image file names that make up the game corpus.

    Args:
        hb_split: If true, read image names from the pickled
            ``test_imgs.pkl`` and ``train_imgs.pkl`` files under
            ``SPLIT_PATH``, in that order, and append ``.png`` to each name.
            If false, list every entry in ``EMPTY_DATA_PATH``, sorted by name.

    Returns:
        The file names with ``.DS_Store`` and the known duplicate images
        removed.

    Note:
        The result is memoized by ``functools.cache``, so repeated calls with
        the same argument return the *same* list object. Callers must not
        mutate it.
    """
    if hb_split:
        # 912 images
        # NOTE(review): pickle.load can execute arbitrary code. These split
        # files are assumed to be trusted, repository-local data.
        paths: List[str] = []
        for split_file in ("test_imgs.pkl", "train_imgs.pkl"):
            with open(os.path.join(SPLIT_PATH, split_file), "rb") as f:
                paths.extend(pickle.load(f))
        paths = [path + ".png" for path in paths]
    else:
        # 1013 images
        # os.listdir returns entries in arbitrary, filesystem-dependent
        # order. Sort them so seeded sampling downstream is reproducible.
        paths = sorted(os.listdir(EMPTY_DATA_PATH))

    dup_images = frozenset({"page6-51.png", "page6-66.png", "page4-170.png"})
    # .DS_Store is macOS folder metadata, not an image.
    return [
        path for path in paths if path != ".DS_Store" and path not in dup_images
    ]