Flux9665's picture
update to current version
6a79837
import json
import os
import pickle
import random
import kan
import torch
from tqdm import tqdm
from Modules.ToucanTTS.InferenceToucanTTS import ToucanTTS
from Utility.utils import load_json_from_path
class MetricsCombiner(torch.nn.Module):
    """Combine three per-language-pair distance metrics into one learned score.

    The scoring function is a Kolmogorov-Arnold network built with
    ``kan.KAN(width=[3, 5, 1], ...)``, i.e. it is configured for three input
    features and a single output value.
    """

    def __init__(self, m):
        """Build the scoring network.

        Args:
            m: Value forwarded as ``seed`` to ``kan.KAN``, so that each member
                of an ensemble is initialised differently.
        """
        super().__init__()
        self.scoring_function = kan.KAN(width=[3, 5, 1], grid=5, k=5, seed=m)

    def forward(self, x):
        """Score a batch of metric triples.

        Args:
            x: Tensor of metric features. Singleton dimensions are removed
                before scoring; presumably shaped ``(batch, 3)``.

        Returns:
            Whatever ``self.scoring_function`` returns for the 2-D input.
        """
        x = x.squeeze()
        # A bare squeeze() also drops the batch axis when the batch holds a
        # single sample (shape (1, 3) -> (3,)). The scoring network expects a
        # (batch, features) input, so restore the batch axis in that case.
        if x.dim() == 1:
            x = x.unsqueeze(0)
        return self.scoring_function(x)
class EnsembleModel(torch.nn.Module):
    """Average the outputs of several models applied to the same input."""

    def __init__(self, models):
        """Store the ensemble members.

        Args:
            models: Iterable of callables that all accept the same input. When
                every member is a ``torch.nn.Module`` they are kept in a
                ``torch.nn.ModuleList`` so that ``parameters()``, ``to()``,
                ``eval()`` and ``state_dict()`` of the ensemble reach them.
                Otherwise they are kept in a plain list.
        """
        super().__init__()
        members = list(models)
        if all(isinstance(member, torch.nn.Module) for member in members):
            # A plain Python list hides sub-modules from torch; ModuleList
            # registers them while still supporting iteration, len and indexing.
            self.models = torch.nn.ModuleList(members)
        else:
            self.models = members

    def forward(self, x):
        """Return the arithmetic mean of all member outputs for ``x``."""
        distances = [model(x) for model in self.models]
        return sum(distances) / len(distances)
def _pair_metric_features(lang_1, lang_2, tree_dist, map_dist, map_dist_scale, asp_sim, asp_column):
    """Return the ``[tree, map, asp]`` distance features of one language pair as a float32 tensor.

    The tree and map lookups are tried as ``[lang_2][lang_1]`` first and as
    ``[lang_1][lang_2]`` second, because a pair may be stored in either order.

    Raises:
        KeyError: The pair is missing from a lookup in both orders, or
            ``lang_1`` has no row in ``asp_sim``.
        ValueError: ``lang_2`` has no column in ``asp_sim``.
    """
    try:
        tree = tree_dist[lang_2][lang_1]
    except KeyError:
        tree = tree_dist[lang_1][lang_2]
    try:
        scaled_map = map_dist[lang_2][lang_1] / map_dist_scale
    except KeyError:
        scaled_map = map_dist[lang_1][lang_2] / map_dist_scale
    asp_row = asp_sim[lang_1]
    try:
        column = asp_column[lang_2]
    except KeyError:
        # Keep the exception type of the list.index() call this lookup replaces.
        raise ValueError(f"{lang_2!r} is not in list") from None
    asp = 1.0 - asp_row[column]
    return torch.tensor([tree, scaled_map, asp], dtype=torch.float32)


def _sample_batch(train_set, batch_size, embedding_provider, iso_codes_to_ids, featurize):
    """Draw ``batch_size`` random language pairs and return ``(features, targets)``.

    ``targets`` holds the MSE between the two language embeddings of each pair,
    ``features`` holds the output of ``featurize(lang_1, lang_2)`` for each pair.
    Both are stacked and squeezed.
    """
    features = list()
    targets = list()
    for _ in range(batch_size):
        lang_1 = random.sample(train_set, 1)[0]
        lang_2 = random.sample(train_set, 1)[0]
        targets.append(torch.nn.functional.mse_loss(
            embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_1]])).squeeze(),
            embedding_provider(torch.LongTensor([iso_codes_to_ids[lang_2]])).squeeze()))
        features.append(featurize(lang_1, lang_2))
    return torch.stack(features).squeeze(), torch.stack(targets).squeeze()


def create_learned_cache(model_path, cache_root="."):
    """Train an ensemble that predicts language-embedding distance from metric distances and cache its predictions.

    An ensemble of ``MetricsCombiner`` models is fitted to map the tree, map and
    ASP distance of a language pair onto the MSE between the two language
    embeddings taken from the checkpoint. The ensemble prediction for every
    pair of languages in the tree-distance lookup is then written to
    ``lang_1_to_lang_2_to_learned_dist.json`` inside ``cache_root``.

    Args:
        model_path: Path to a checkpoint holding ``"model"`` and ``"config"`` entries.
        cache_root: Directory that contains ``supervised_languages.json``,
            ``lang_1_to_lang_2_to_tree_dist.json``, ``lang_1_to_lang_2_to_map_dist.json``,
            ``asp_dict.pkl`` and ``iso_lookup.json``, and that receives the output file.

    Raises:
        FileNotFoundError: One of the tree/map lookups or the ASP dictionary is missing.
    """
    checkpoint = torch.load(model_path, map_location='cpu')
    embedding_provider = ToucanTTS(weights=checkpoint["model"], config=checkpoint["config"]).encoder.language_embedding
    embedding_provider.requires_grad_(False)
    language_list = load_json_from_path(os.path.join(cache_root, "supervised_languages.json"))
    tree_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_tree_dist.json")
    map_lookup_path = os.path.join(cache_root, "lang_1_to_lang_2_to_map_dist.json")
    asp_dict_path = os.path.join(cache_root, "asp_dict.pkl")
    if not os.path.exists(tree_lookup_path) or not os.path.exists(map_lookup_path):
        raise FileNotFoundError("Please ensure the caches exist!")
    if not os.path.exists(asp_dict_path):
        raise FileNotFoundError(f"{asp_dict_path} must be downloaded separately.")
    tree_dist = load_json_from_path(tree_lookup_path)
    map_dist = load_json_from_path(map_lookup_path)
    # NOTE: unpickling executes arbitrary code; only use an asp_dict.pkl from a trusted source.
    with open(asp_dict_path, 'rb') as dictfile:
        asp_sim = pickle.load(dictfile)
    # Position of each language in the rows of asp_sim, built once so that the
    # inner loops do not scan a list for every sampled pair.
    asp_column = {lang: position for position, lang in enumerate(asp_sim)}
    largest_value_map_dist = 0.0
    for values in map_dist.values():
        for value in values.values():
            largest_value_map_dist = max(largest_value_map_dist, value)
    iso_codes_to_ids = load_json_from_path(os.path.join(cache_root, "iso_lookup.json"))[-1]
    train_set = language_list
    batch_size = 128
    model_list = list()
    print_intermediate_results = False

    def featurize(lang_1, lang_2):
        return _pair_metric_features(lang_1, lang_2, tree_dist, map_dist, largest_value_map_dist, asp_sim, asp_column)

    # ensemble preparation
    n_models = 5
    print(f"Training ensemble of {n_models} models for learned distance metric.")
    for m in range(n_models):
        model = MetricsCombiner(m)
        model_list.append(model)
        optim = torch.optim.Adam(model.parameters(), lr=0.0005)
        running_loss = list()
        for _epoch in tqdm(range(35), desc=f"MetricsCombiner {m + 1}/{n_models} - Epoch"):
            for _ in range(1000):
                # we have no dataloader, so first we build a batch
                metric_distances, embedding_distances = _sample_batch(train_set, batch_size, embedding_provider, iso_codes_to_ids, featurize)
                # ok now we have a batch prepared. Time to feed it to the model.
                scores = model(metric_distances)
                if print_intermediate_results:
                    print("==================================")
                    print(scores.detach().squeeze()[:9])
                    print(embedding_distances[:9])
                # squared error relative to the target distance, so that small
                # distances are not drowned out by large ones
                loss = torch.nn.functional.mse_loss(scores.squeeze(), embedding_distances, reduction="none")
                loss = loss / (embedding_distances + 0.0001)
                loss = loss.mean()
                running_loss.append(loss.item())
                optim.zero_grad()
                loss.backward()
                optim.step()
            print("\n\n")
            print(sum(running_loss) / len(running_loss))
            print("\n\n")
            running_loss = list()
        # model_list[-1].scoring_function.plot(folder=f"kan_vis_{m}", beta=5000)
        # plt.show()

    # Time to see if the final ensemble is any good
    ensemble = EnsembleModel(model_list)
    running_loss = list()
    with torch.no_grad():  # evaluation only, no gradients are needed
        for _ in range(100):
            # we have no dataloader, so first we build a batch
            metric_distances, embedding_distances = _sample_batch(train_set, batch_size, embedding_provider, iso_codes_to_ids, featurize)
            scores = ensemble(metric_distances)
            print("==================================")
            print(scores.detach().squeeze()[:9])
            print(embedding_distances[:9])
            loss = torch.nn.functional.mse_loss(scores.squeeze(), embedding_distances)
            running_loss.append(loss.item())
    print("\n\n")
    print(sum(running_loss) / len(running_loss))

    language_to_language_to_learned_distance = dict()
    for lang_1 in tqdm(tree_dist):
        for lang_2 in tree_dist:
            if lang_1 in language_to_language_to_learned_distance.get(lang_2, ()):
                continue  # it's symmetric
            language_to_language_to_learned_distance.setdefault(lang_1, dict())
            try:
                metric_distance = featurize(lang_1, lang_2)
                with torch.inference_mode():
                    predicted_distance = ensemble(metric_distance.unsqueeze(0)).squeeze()
                language_to_language_to_learned_distance[lang_1][lang_2] = predicted_distance.item()
            except (ValueError, KeyError):
                # pairs for which a metric is unavailable are left out of the cache
                continue
    with open(os.path.join(cache_root, 'lang_1_to_lang_2_to_learned_dist.json'), 'w', encoding='utf-8') as f:
        json.dump(language_to_language_to_learned_distance, f, ensure_ascii=False, indent=4)
if __name__ == "__main__":
    # Script entry point: build the learned-distance cache from the default
    # checkpoint, reading and writing cache files in the current directory.
    create_learned_cache(model_path="../../Models/ToucanTTS_Meta/best.pt")