# NOTE(review): the three lines originally here ("Spaces:" / "Runtime error" /
# "Runtime error") were a status banner scraped from the hosting web page, not
# Python source — commented out so the module parses.
import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json
import torch
from torch import nn
# FIXME: unexplained "#FIX" marker left by the original author — presumably
# flags the local imports below; verify against the project layout.
import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """
    Run one training epoch over train_loader.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        dataloader yielding training batches
    optimizer: torch.optim.Optimizer
        optimizer stepped once per batch
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
        learning-rate scheduler
    step: str ("batch" or "epoch")
        if "batch", the scheduler is stepped after every batch here;
        otherwise the caller is expected to step it once per epoch.

    Returns:
    --------
    loss_meter: AvgMeter
        running (batch-size weighted) average of the training loss
    """
    meter = AvgMeter()  # tracks the weighted average loss for the epoch
    progress = tqdm(train_loader, total=len(train_loader))
    for cpu_batch in progress:
        # Move each tokenized field (dict of tensors) onto the device;
        # "id" is dropped and "image" (a plain tensor) is handled below.
        batch = {
            key: {name: tensor.to(CFG.device) for name, tensor in sub.items()}
            for key, sub in cpu_batch.items()
            if key not in ["id", "image"]
        }
        if "image" in cpu_batch:
            batch["image"] = cpu_batch["image"].to(CFG.device)

        # Forward pass: embeddings for both modalities, then the model's loss.
        first_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(first_embeddings, text_embeddings)

        # Backpropagate and take an optimizer step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()

        # Update the running loss, weighted by this batch's size.
        batch_size = batch["text"]["input_ids"].size(0)
        meter.update(loss.item(), batch_size)
        progress.set_postfix(train_loss=meter.avg, lr=get_lr(optimizer))
    return meter
def valid_epoch(model, valid_loader):
    """
    Run one validation epoch over valid_loader.

    Gradient handling (model.eval() / torch.no_grad()) is the caller's
    responsibility — this function only computes and aggregates the loss.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to validate
    valid_loader: torch.utils.data.DataLoader
        dataloader yielding validation batches

    Returns:
    --------
    loss_meter: AvgMeter
        running (batch-size weighted) average of the validation loss
    """
    meter = AvgMeter()  # tracks the weighted average loss for the epoch
    progress = tqdm(valid_loader, total=len(valid_loader))
    for cpu_batch in progress:
        # Move each tokenized field onto the device; "id" is dropped and
        # "image" (a plain tensor) is transferred separately.
        batch = {
            key: {name: tensor.to(CFG.device) for name, tensor in sub.items()}
            for key, sub in cpu_batch.items()
            if key not in ["id", "image"]
        }
        if "image" in cpu_batch:
            batch["image"] = cpu_batch["image"].to(CFG.device)

        # Forward pass and loss only — no backprop during validation.
        first_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(first_embeddings, text_embeddings)

        batch_size = batch["text"]["input_ids"].size(0)
        meter.update(loss.item(), batch_size)
        progress.set_postfix(valid_loss=meter.avg)
    return meter
def test(model, test_dataset):
    """
    Compute matching accuracy of the model on a test set.

    Used for PoemTextModel only, since CLIPModel has no test set of
    image-poem pairs.

    Parameters:
    -----------
    model: PoemTextModel
        model to evaluate
    test_dataset: list of dict
        data to test on (each dict must have "text" and "poem" keys)

    Returns:
    --------
    accuracy: float
        fraction of test items whose poem was matched correctly
    """
    loader = build_loaders(test_dataset, mode="test")
    n_correct = 0
    progress = tqdm(loader, total=len(loader))
    model.eval()
    with torch.no_grad():
        for cpu_batch in progress:
            # Move each tokenized field onto the device; "id" is dropped and
            # "image" (a plain tensor) is transferred separately.
            batch = {
                key: {name: tensor.to(CFG.device) for name, tensor in sub.items()}
                for key, sub in cpu_batch.items()
                if key not in ["id", "image"]
            }
            if "image" in cpu_batch:
                batch["image"] = cpu_batch["image"].to(CFG.device)

            # predict() returns, per text, the index of the poem it matched.
            predictions = model.predict(batch).cpu().numpy()
            batch_size = batch["text"]["input_ids"].size(0)
            # Ground truth: item i's poem sits at index i, so the correct
            # labels are simply 0..batch_size-1.
            correct = np.sum(predictions == np.arange(batch_size))
            n_correct += correct
            progress.set_postfix(accuracy=correct / batch_size)
    return n_correct / len(test_dataset)
def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """
    Train and validate the model for `epochs` epochs, checkpointing on the
    best validation loss.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        train dataloader to get batches from
    valid_loader: torch.utils.data.DataLoader
        validation dataloader to get batches from
    epochs: int, optional
        the number of epochs to train (defaults to CFG.epochs)

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        trained model (weights of the last epoch; the best-validation
        checkpoint is saved to disk via model.save_current())
    loss_history: dict
        train and validation average loss for each epoch
    """
    # AdamW optimizer and ReduceLROnPlateau scheduler with settings from config.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # "epoch": the scheduler steps once per epoch on validation loss (below).
    # ("batch" would make train_epoch step it after every batch instead.)
    step = "epoch"
    loss_history = {"train": [], "valid": []}
    # Best (lowest) validation loss seen so far, for checkpointing.
    best_loss = float('inf')
    # BUGFIX: iterate over the `epochs` argument rather than CFG.epochs —
    # previously the parameter was accepted but silently ignored.
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")
        # Train for one epoch.
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)
        # Validate the trained model without tracking gradients.
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        loss_history["valid"].append(valid_loss.avg)
        # Checkpoint whenever validation loss improves.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")
        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)
    return model, loss_history