import argparse
import os
import sys
import time

import numpy as np
import tensorflow as tf
import torch

tf.compat.v1.disable_v2_behavior()

# Make the repository root importable so `utils` and `models` resolve.
sys.path.append(os.path.split(os.path.dirname(os.path.abspath(__file__)))[0])

from utils.hparams import HParams
from models import get_model

set_size = 200   # total rows per model call (observed + generated)
threshold = 100  # max number of observed rows conditioned on per call


def fast_cosine_dist(source_feats, matching_pool):
    """Pairwise cosine distances between (T, D) source_feats and (N, D) matching_pool."""
    source_norms = torch.norm(source_feats, p=2, dim=-1)
    matching_norms = torch.norm(matching_pool, p=2, dim=-1)
    # Recover dot products from squared Euclidean distances:
    # a.b = (||a||^2 + ||b||^2 - ||a - b||^2) / 2
    dotprod = -torch.cdist(source_feats[None], matching_pool[None], p=2)[0] ** 2 \
        + source_norms[:, None] ** 2 + matching_norms[None] ** 2
    dotprod /= 2
    dists = 1 - dotprod / (source_norms[:, None] * matching_norms[None])
    return dists
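
# Quick self-check (illustrative only, with made-up shapes): the result should
# agree with a direct cosine-distance computation.
#   a, b = torch.randn(5, 8), torch.randn(7, 8)
#   ref = 1 - (a @ b.T) / (a.norm(dim=-1, keepdim=True) * b.norm(dim=-1)[None])
#   assert torch.allclose(fast_cosine_dist(a, b), ref, atol=1e-5)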


def evaluate(batch, model):
    # Run the model's sampling op on `batch` and return the generated set.
    return model.execute(model.sample, batch)


def prematch(path, expanded):
    # Build the candidate pool: every other utterance from the same speaker,
    # plus the expanded samples.
    uttrs_from_same_spk = sorted(path.parent.rglob('**/*.pt'))
    uttrs_from_same_spk.remove(path)
    candidates = [torch.load(each, map_location='cpu') for each in uttrs_from_same_spk]
    candidates = torch.cat(candidates, 0)
    # Cast both sides to float32 so torch.cdist sees matching dtypes.
    candidates = torch.cat([candidates, torch.tensor(expanded)], 0).to(torch.float32)
    source_feats = torch.load(path, map_location='cpu').to(torch.float32)
    dists = fast_cosine_dist(source_feats, candidates)
    best = dists.topk(k=args.topk, dim=-1, largest=False)  # indices: (src_len, topk)
    out_feats = candidates[best.indices].mean(dim=1)       # (src_len, dim)
    return out_feats
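
# Usage sketch (hypothetical file layout): prematch expects a pathlib.Path,
# since it walks path.parent for sibling utterances of the same speaker, and
# it reads args.topk from module scope (set in __main__ below).
#   from pathlib import Path
#   utt = Path('feats/spk1/utt_001.pt')
#   target_feats = prematch(utt, expanded=single_expand(utt, model, 100))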
    

def single_expand(path, model, num_samples, seed=1234, out_path=None):
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)
    matching_set = torch.load(path, map_location=torch.device('cpu')).numpy()
    # Features are scaled down by 10 for the model; generated rows are scaled
    # back up below.
    matching_set = matching_set / 10
    matching_size = matching_set.shape[0]
    new_samples = []
    cur_num_samples = 0
    while cur_num_samples < num_samples:
        batch = dict()
        if matching_size < threshold:
            # Condition on the whole matching set and pad the remaining slots.
            num_new_samples = set_size - matching_size
            padded_data = np.zeros((num_new_samples, matching_set.shape[1]))
            batch['b'] = np.concatenate([np.ones_like(matching_set), np.zeros_like(padded_data)], 0)[None, ...]
            batch['x'] = np.concatenate([matching_set, padded_data], axis=0)[None, ...]
            batch['m'] = np.ones_like(batch['b'])
            sample = evaluate(batch, model)
            new_sample = sample[0, matching_size:] * 10
        else:
            # Condition on a random subset of `threshold` rows.
            num_new_samples = set_size - threshold
            ind = np.random.choice(matching_size, threshold, replace=False)
            padded_data = np.zeros((num_new_samples, matching_set.shape[1]))
            obs_data = matching_set[ind]
            batch['x'] = np.concatenate([obs_data, padded_data], 0)[None, ...]
            batch['b'] = np.concatenate([np.ones_like(obs_data), np.zeros_like(padded_data)], 0)[None, ...]
            batch['m'] = np.ones_like(batch['b'])
            sample = evaluate(batch, model)
            # Observed rows fill the first `threshold` positions, so generated
            # rows start at index `threshold`.
            new_sample = sample[0, threshold:, :] * 10
        cur_num_samples += num_new_samples
        new_samples.append(new_sample)
    new_samples = np.concatenate(new_samples, 0)[:num_samples]
    if out_path:
        # Guard against an empty dirname when out_path is a bare filename.
        os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)
        np.save(out_path, new_samples)
    return new_samples
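
# Note on the saved file: np.save appends '.npy' when the target name lacks
# that extension, so the default --out_path of 'expanded_set.pt' is written
# to disk as 'expanded_set.pt.npy'. Read it back with NumPy, e.g.:
#   expanded = np.load('expanded_set.pt.npy')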


def single_expand_fast(path):
    # Batched variant of single_expand: generates all samples in one model
    # call. Reads `model` and `args` from module scope (set in __main__).
    matching_set = torch.load(path, map_location='cpu').numpy()
    matching_set = matching_set / 10
    matching_size = matching_set.shape[0]
    batch = dict()
    if matching_size < threshold:
        num_new_samples = set_size - matching_size
    else:
        num_new_samples = set_size - threshold
    # Ceiling division so the batch yields at least args.num_samples rows.
    batch_size = int(np.ceil(args.num_samples / num_new_samples))
    if matching_size < threshold:
        padded_data = np.zeros((num_new_samples, matching_set.shape[1]))
        batch['b'] = np.concatenate([np.ones_like(matching_set), np.zeros_like(padded_data)], 0)[None, ...]
        batch['x'] = np.concatenate([matching_set, padded_data], axis=0)[None, ...]
        batch['b'] = np.tile(batch['b'], (batch_size, 1, 1))
        batch['x'] = np.tile(batch['x'], (batch_size, 1, 1))
        batch['m'] = np.ones_like(batch['b'])
        sample = evaluate(batch, model)
        new_samples = sample[:, matching_size:, :] * 10
    else:
        padded_data = np.zeros((num_new_samples, matching_set.shape[1]))
        batch['x'] = []
        for i in range(batch_size):
            # Each batch element conditions on a fresh random subset.
            ind = np.random.choice(matching_size, threshold, replace=False)
            obs_data = matching_set[ind]
            batch['x'].append(np.concatenate([obs_data, padded_data], 0)[None, ...])
        batch['x'] = np.concatenate(batch['x'], 0)
        batch['b'] = np.concatenate([np.ones_like(obs_data), np.zeros_like(padded_data)], 0)[None, ...]
        batch['b'] = np.tile(batch['b'], (batch_size, 1, 1))
        batch['m'] = np.ones_like(batch['b'])
        sample = evaluate(batch, model)
        # Observed rows fill the first `threshold` positions.
        new_samples = sample[:, threshold:, :] * 10
    new_samples = new_samples.reshape((-1, new_samples.shape[-1]))
    return new_samples[:args.num_samples]
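
# single_expand_fast is not invoked by __main__ below, and it only works after
# the argument parsing and model construction there have run. A sketch:
#   expanded = single_expand_fast('matching_set.pt')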

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg_file', type=str)
    parser.add_argument('--seed', type=int, default=1234)
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--num_samples', type=int, default=100)
    parser.add_argument('--path', type=str, default="matching_set.pt")
    parser.add_argument('--out_path', type=str, default="expanded_set.pt")
    parser.add_argument('--topk', type=int, default=4)
    args = parser.parse_args()
    params = HParams(args.cfg_file)

    # Restrict visible GPUs before the model (and any TF session) is created;
    # setting CUDA_VISIBLE_DEVICES after initialization has no effect.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    t0 = time.time()
    model = get_model(params)
    model.load()
    t1 = time.time()
    print(f"{t1 - t0:.2f}s to load the model")

    path = args.path
    if path.endswith(".pt"):
        t0 = time.time()
        expanded = single_expand(path, model, args.num_samples, args.seed, args.out_path)
        t1 = time.time()
        print(f"{t1 - t0:.2f}s to expand the set")
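
# Example invocation (script and config file names are placeholders):
#   python expand_set.py --cfg_file config.cfg --path matching_set.pt \
#       --out_path expanded_set --num_samples 100 --gpu 0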