CosyVoice committed on
Commit
6a3e442
1 Parent(s): ee9e87b

keep only embedding mean as spk embedding

Browse files
cosyvoice/dataset/processor.py CHANGED
@@ -167,7 +167,7 @@ def parse_embedding(data, normalize, mode='train'):
167
  """
168
  for sample in data:
169
  sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
170
- sample['spk_embedding'] = torch.stack([torch.tensor(i, dtype=torch.float32) for i in sample['spk_embedding']], dim=0).mean(dim=0)
171
  if normalize:
172
  sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
173
  sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
 
167
  """
168
  for sample in data:
169
  sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
170
+ sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
171
  if normalize:
172
  sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
173
  sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
tools/extract_embedding.py CHANGED
@@ -53,6 +53,8 @@ def main(args):
53
  if spk not in spk2embedding:
54
  spk2embedding[spk] = []
55
  spk2embedding[spk].append(embedding)
 
 
56
 
57
  torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
58
  torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
 
53
  if spk not in spk2embedding:
54
  spk2embedding[spk] = []
55
  spk2embedding[spk].append(embedding)
56
+ for k, v in spk2embedding.items():
57
+ spk2embedding[k] = torch.tensor(v).mean(dim=0, keepdim=True).tolist()
58
 
59
  torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
  torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))