Spaces:
Runtime error
Runtime error
from utils import get_datasets, build_loaders | |
from models import PoemTextModel | |
from train import train, test | |
from metrics import calc_metrics | |
from inference import predict_poems_from_text | |
from utils import get_poem_embeddings | |
import config as CFG | |
import json | |
def _output_path(prefix):
    """Return the JSON output filename for *prefix*, tagged with the configured encoders.

    The result is ``'{prefix}_{CFG.poem_encoder_model}_{CFG.text_encoder_model}.json'``,
    so runs with different encoder configurations do not overwrite each other.
    """
    return '{}_{}_{}.json'.format(prefix, CFG.poem_encoder_model, CFG.text_encoder_model)


def _write_json(obj, prefix, **dump_kwargs):
    """Write *obj* as 4-space-indented JSON to ``_output_path(prefix)`` (UTF-8).

    Extra keyword arguments (e.g. ``ensure_ascii=False``) are forwarded to
    ``json.dumps``. The object is serialized *before* the file is opened, so a
    serialization error does not leave behind an empty or truncated file.
    """
    text = json.dumps(obj, indent=4, **dump_kwargs)
    with open(_output_path(prefix), 'w', encoding="utf-8") as f:
        f.write(text)


def _build_examples(model, poem_embeddings, test_dataset, num_examples=100, top_n=10):
    """Predict beyts for the first *num_examples* test items.

    Returns a dict mapping the example index to a record with the input
    ``'Text'``, the ``'True Beyt'`` and the top-*top_n* ``"Predicted Beyt"`` list
    returned by ``predict_poems_from_text``.
    """
    # The candidate pool is the same for every query, so build it once instead
    # of once per example. NOTE(review): assumes predict_poems_from_text does not
    # mutate this list in place — confirm against inference.py.
    candidate_beyts = [data['beyt'] for data in test_dataset]
    examples = {}
    for i, test_data in enumerate(test_dataset[:num_examples]):
        examples[i] = {
            'Text': test_data["text"],
            'True Beyt': test_data["beyt"],
            "Predicted Beyt": predict_poems_from_text(
                model, poem_embeddings, test_data["text"], candidate_beyts, n=top_n
            ),
        }
    return examples


def _print_examples(examples, num_shown=10):
    """Print up to *num_shown* entries of *examples* (fewer if fewer exist)."""
    # Guard against small test sets: the original indexed 0..9 unconditionally
    # and raised KeyError when fewer than 10 examples were available.
    for i in range(min(num_shown, len(examples))):
        print("Text: ", examples[i]['Text'])
        print("True Beyt: ", examples[i]['True Beyt'])
        print("predicted Beyts: \n\t", "\n\t".join(examples[i]["Predicted Beyt"]))


def main():
    """
    Creates a PoemTextModel based on configs and trains, tests and outputs some examples of its prediction.

    Stages:
      1. Load the train/validation/test splits and build the train/valid loaders.
      2. Train the model and save its loss history to ``loss_history_<poem>_<text>.json``.
      3. Report test accuracy, mean rank and MRR, and save the metrics to
         ``test_metrics_<poem>_<text>.json``.
      4. Predict beyts for the first 100 test items, print the first 10, and
         save all of them to ``example_output__<poem>_<text>.json``.
    """
    # Load the datasets via get_datasets (per the original note, these are the
    # same splits as the train/val/test files in the data directory).
    train_dataset, val_dataset, test_dataset = get_datasets()
    train_loader = build_loaders(train_dataset, mode="train")
    valid_loader = build_loaders(val_dataset, mode="valid")

    # Train a PoemTextModel and save its loss history.
    model = PoemTextModel(poem_encoder_pretrained=True, text_encoder_pretrained=True).to(CFG.device)
    model, loss_history = train(model, train_loader, valid_loader)
    _write_json(loss_history, 'loss_history')

    # Compute accuracy, mean rank and MRR on the test set and save them.
    model.eval()
    print("Accuracy on test set: ", test(model, test_dataset))
    metrics = calc_metrics(test_dataset, model)
    print('mean rank: ', metrics["mean_rank"])
    print('mean reciprocal rank (MRR)', metrics["mean_reciprocal_rank_(MRR)"])
    _write_json(metrics, 'test_metrics')

    # Inference: print some example predictions and save them.
    print("_" * 20)
    print("Output Examples from test set")
    model, poem_embeddings = get_poem_embeddings(test_dataset, model)
    examples = _build_examples(model, poem_embeddings, test_dataset)
    _print_examples(examples)
    # The prefix keeps the historical double underscore in the filename
    # ('example_output__...') so existing consumers still find the file.
    # ensure_ascii=False keeps the (non-ASCII) poem text human-readable.
    _write_json(examples, 'example_output_', ensure_ascii=False)


if __name__ == "__main__":
    main()