Spaces:
Runtime error
Runtime error
import config as CFG | |
import json | |
from models import PoemTextModel | |
import torch | |
import random | |
from datasets import PoemTextDataset, get_transforms, CLIPDataset | |
from tqdm import tqdm | |
import numpy as np | |
class AvgMeter:
    """
    Running average of batch losses collected during training / validation.

    Attributes:
    -----------
    name : str
        Label shown by ``__repr__``.
    count : int
        Number of data points whose loss has been metered so far.
    sum : int or float
        Count-weighted sum of every metered loss.
    avg : int or float
        ``sum / count``; stays 0 until the first ``update``.

    Methods:
    --------
    reset():
        Zeroes count, sum and avg.
    update(val, count=1):
        Adds ``count`` observations of value ``val`` and refreshes avg.
    __repr__():
        Returns "<name>: <avg with 4 decimals>".
    """

    def __init__(self, name="Metric"):
        """Remember the meter's label and start from zeroed statistics."""
        self.name = name
        self.reset()

    def reset(self):
        """Zero the running average, the sum and the count."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, count=1):
        """Fold ``count`` observations of value ``val`` into the running statistics."""
        self.count += count
        self.sum += val * count
        self.avg = self.sum / self.count

    def __repr__(self):
        """Human-readable summary: the meter's name and its average to 4 decimals."""
        return f"{self.name}: {self.avg:.4f}"
def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.

    Only the first group is inspected; if ``optimizer.param_groups`` is empty,
    None is returned.
    """
    return next((group["lr"] for group in optimizer.param_groups), None)
def get_datasets():
    """
    Load the JSON dataset at CFG.dataset_path and split it into train / validation / test parts.

    The loaded records are shuffled deterministically with ``random.Random(CFG.random_seed)``.
    With N records, the first ``int(CFG.train_propotion * N)`` form the train split, the records
    up to ``int((CFG.train_propotion + CFG.val_propotion) * N)`` form the validation split, and the
    remainder forms the test split.

    Returns:
    --------
    train_dataset, val_dataset, test_dataset
        The three splits as produced by ``np.split``, i.e. 1-D numpy arrays rather than Python
        lists. Their elements are the JSON records (presumably dicts — depends on the file).
    """
    with open(CFG.dataset_path, encoding="utf-8") as json_file:
        records = json.load(json_file)
    random.Random(CFG.random_seed).shuffle(records)

    total = len(records)
    train_end = int(CFG.train_propotion * total)
    val_end = int((CFG.train_propotion + CFG.val_propotion) * total)
    # Split recipe: https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_part, val_part, test_part = np.split(records, [train_end, val_end])
    return train_part, val_part, test_part
def build_loaders(dataset_dict, mode):
    """
    Wrap ``dataset_dict`` in a PoemTextDataset and return a DataLoader over it.

    Parameters:
    -----------
    dataset_dict: list of dict
        Records to serve.
    mode: str
        Shuffling is enabled only when this equals "train"; any other value
        yields a loader that preserves the dataset order.

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        Loader using CFG.batch_size and CFG.num_workers.
    """
    shuffle = mode == "train"
    poem_dataset = PoemTextDataset(dataset_dict)
    return torch.utils.data.DataLoader(
        poem_dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=shuffle,
    )
def get_clip_datasets(dataset_dict, seed=None, train_proportion=None, val_proportion=None):
    """
    (Used for clip model training) Shuffle ``dataset_dict`` and split it into train / validation / test.

    The input is NOT modified: a copy is shuffled with ``random.Random(seed)``. With N records,
    the first ``int(train_proportion * N)`` shuffled records form the train split, the records up
    to ``int((train_proportion + val_proportion) * N)`` form the validation split, and the rest
    form the test split.

    Parameters:
    -----------
    dataset_dict: sequence of dict
        The input dataset (any iterable of records; a list of dict in practice).
    seed: optional
        Seed for the shuffle. Defaults to CFG.random_seed when None.
    train_proportion: float, optional
        Fraction of records for the train split. Defaults to CFG.train_propotion when None.
    val_proportion: float, optional
        Fraction of records for the validation split. Defaults to CFG.val_propotion when None.

    Returns:
    --------
    train_dataset, val_dataset, test_dataset
        The three splits as produced by ``np.split``, i.e. 1-D numpy arrays rather than
        Python lists.
    """
    # Fall back to the project configuration only for values the caller did not supply.
    if seed is None:
        seed = CFG.random_seed
    if train_proportion is None:
        train_proportion = CFG.train_propotion
    if val_proportion is None:
        val_proportion = CFG.val_propotion

    # Shuffle a copy so the caller's list is left untouched (and tuples are accepted too).
    shuffled = list(dataset_dict)
    random.Random(seed).shuffle(shuffled)

    total = len(shuffled)
    cut_points = [int(train_proportion * total), int((train_proportion + val_proportion) * total)]
    # https://stackoverflow.com/questions/38250710/how-to-split-data-into-3-sets-train-validation-and-test
    train_dataset, val_dataset, test_dataset = np.split(shuffled, cut_points)
    return train_dataset, val_dataset, test_dataset
def build_image_loaders(dataset_dict, mode):
    """
    (Used for clip model training) Returns a torch DataLoader over image/text records.

    First builds a CLIPDataset from ``dataset_dict``, using the transforms returned by
    ``get_transforms(mode=mode)`` and ``is_image_poem_pair=False``, then wraps it in a DataLoader.

    Parameters:
    -----------
    dataset_dict: list of dict
        The dataset to return a dataloader of.
    mode: str ("train" or any other word)
        Passed to get_transforms; additionally, if the mode is "train", the dataloader shuffles.

    Returns:
    --------
    dataloader: torch.utils.data.DataLoader
        The DataLoader over the CLIPDataset, using CFG.batch_size and CFG.num_workers.
    """
    transforms = get_transforms(mode=mode)
    dataset = CLIPDataset(
        dataset_dict, transforms, is_image_poem_pair=False
    )
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        shuffle=mode == "train",
    )
    return dataloader
def get_poem_embeddings(test_dataset, model=None):
    """
    Returns embeddings of the poems existing in test_dataset.

    The poems are tokenized through ``build_loaders(test_dataset, mode="test")``, so the order
    of the returned embeddings follows the dataset order. Gradients are disabled during encoding.

    Parameters:
    -----------
    test_dataset: list of dict
        Dataset to get poems from. Each of its dictionaries must have a "beyt" key.
    model: PoemTextModel or CLIPModel, optional
        The model to get poem embeddings from. Note that ``model.eval()`` is called on it.
        If None is given, instantiates a new PoemTextModel (with all of its parts in pretrained
        settings) using configurations provided in config.py, and moves it to CFG.device.

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        The model used for creating poem embeddings.
    poem_embeddings: torch.Tensor
        Projected embeddings of all poems, concatenated along dim 0 (on CFG.device).

    Raises:
    -------
    RuntimeError
        If ``model`` is neither a PoemTextModel nor a CLIPModel (matched by class name).
    """
    test_loader = build_loaders(test_dataset, mode="test")  # building a dataloader (which also tokenizes the poems)
    if model is None:
        # NOTE(review): the four positional flags are interpreted by PoemTextModel; confirm in models.py.
        model = PoemTextModel(True, False, True, False, poem_projection_pretrained=True, text_projection_pretrained=True).to(CFG.device)
    model.eval()

    # The encoder/projection pair depends only on the model type, so resolve it once, before
    # iterating, instead of re-checking the class name for every batch.
    model_type = model.__class__.__name__
    if model_type == "PoemTextModel":
        encoder, projection = model.poem_encoder, model.poem_projection
    elif model_type == "CLIPModel":
        encoder, projection = model.encoder, model.text_projection
    else:
        # The original bare `raise` had no active exception to re-raise; give a real reason.
        raise RuntimeError(
            f"get_poem_embeddings: unsupported model type {model_type!r}; "
            "expected PoemTextModel or CLIPModel"
        )

    poem_embeddings = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            # batch["beyt"] is presumably the tokenizer output (dict of tensors) built by
            # PoemTextDataset; move every tensor to the model's device.
            beyts = {
                key: values.to(CFG.device)
                for key, values in batch["beyt"].items()
            }
            poem_features = encoder(input_ids=beyts["input_ids"], attention_mask=beyts["attention_mask"])
            poem_embeddings.append(projection(poem_features))
    return model, torch.cat(poem_embeddings)