# Source: mojtaba-nafez — "add initial files to deploy" (commit 2fa2727)
import os
import gc
import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import json
import torch
from torch import nn
# Project-local imports
import config as CFG
from models import CLIPModel
from utils import AvgMeter, get_lr
from utils import get_datasets, build_loaders
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """
    Performs one epoch of training.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        dataloader to get batches from
    optimizer: torch.optim.Optimizer
        optimizer used for training
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler
        scheduler used for training
    step: str ("batch" or "epoch")
        if "batch", lr_scheduler will step (update) for each batch of loader.
        else lr_scheduler only steps and updates after finishing each epoch.

    Returns:
    --------
    loss_meter: AvgMeter
        the class containing average loss of this epoch's training
    """
    loss_meter = AvgMeter()  # tracks the running (weighted) average of the loss
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    for batch_cpu in tqdm_object:
        # Move the nested tokenizer dicts (e.g. "text", "poem") to the device.
        # "id" is skipped entirely; "image" is a plain tensor (not a dict of
        # tensors) so it is moved separately below.
        batch = {
            k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
            for k, v in batch_cpu.items()
            if k not in ("id", "image")
        }
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)
        # Get the model's paired embeddings and compute the loss between them.
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
        # Backpropagate and update parameters.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step == "batch":
            lr_scheduler.step()
        # Update training statistics, weighting by the actual batch size.
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, lr=get_lr(optimizer))
    return loss_meter
def valid_epoch(model, valid_loader):
    """
    Performs one epoch of validation.

    NOTE: this function does not call model.eval() or torch.no_grad() itself;
    the caller (see train()) is expected to set up evaluation mode.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to validate
    valid_loader: torch.utils.data.DataLoader
        dataloader to get batches from.

    Returns:
    --------
    loss_meter: AvgMeter
        the class containing average loss of this epoch's validation
    """
    loss_meter = AvgMeter()  # tracks the running (weighted) average of the loss
    tqdm_object = tqdm(valid_loader, total=len(valid_loader))
    for batch_cpu in tqdm_object:
        # Move the nested tokenizer dicts to the device; "id" is skipped and
        # "image" (a plain tensor, not a dict of tensors) is handled below.
        batch = {
            k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
            for k, v in batch_cpu.items()
            if k not in ("id", "image")
        }
        if "image" in batch_cpu:
            batch["image"] = batch_cpu["image"].to(CFG.device)
        # Get the model's paired embeddings and compute the loss between them.
        poem_or_img_embeddings, text_embeddings = model(batch)
        loss = model.calculate_loss(poem_or_img_embeddings, text_embeddings)
        # Update validation statistics, weighting by the actual batch size.
        count = batch["text"]["input_ids"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(valid_loss=loss_meter.avg)
    return loss_meter
def test(model, test_dataset):
    """
    Calculates accuracy on the test set.

    This method is used for the PoemTextModel, since the other model
    (CLIPModel) does not have a test set containing pairs of image-poem.

    Parameters:
    -----------
    model: PoemTextModel
        model to test
    test_dataset: list of dict
        the list containing dict of data to perform test on
        (must have "text" and "poem" keys)

    Returns:
    --------
    accuracy: float
        The accuracy of the model on the given test set
    """
    test_loader = build_loaders(test_dataset, mode="test")
    accuracy = 0
    tqdm_object = tqdm(test_loader, total=len(test_loader))
    model.eval()
    with torch.no_grad():
        for batch_cpu in tqdm_object:
            # Move the nested tokenizer dicts to the device; "id" is skipped
            # and "image" (a plain tensor) is handled separately below.
            batch = {
                k: {dict_k: dict_v.to(CFG.device) for dict_k, dict_v in v.items()}
                for k, v in batch_cpu.items()
                if k not in ("id", "image")
            }
            if "image" in batch_cpu:
                batch["image"] = batch_cpu["image"].to(CFG.device)
            # Model prediction: a numpy array of indices/labels saying which
            # poem belongs to which text.
            pred = model.predict(batch).cpu().numpy()
            count = batch["text"]["input_ids"].size(0)
            # Each text is paired with the poem at the same index, so
            # np.arange(count) is the ground-truth label vector.
            acc = np.sum(pred == np.arange(count))
            accuracy += acc
            tqdm_object.set_postfix(accuracy=acc / count)
    # Normalize the summed per-batch correct counts by the dataset size.
    accuracy /= len(test_dataset)
    return accuracy
def train(model, train_loader, valid_loader, epochs=CFG.epochs):
    """
    Performs train and validation for (epochs) epochs.

    Parameters:
    -----------
    model: PoemTextModel or CLIPModel
        model to train
    train_loader: torch.utils.data.DataLoader
        train dataloader to get batches from
    valid_loader: torch.utils.data.DataLoader
        validation dataloader to get batches from
    epochs: int, optional
        the number of epochs to train (defaults to CFG.epochs)

    Returns:
    --------
    model: PoemTextModel or CLIPModel
        trained model
    loss_history: dict
        a dict containing train and validation average loss for each epoch.
    """
    # AdamW optimizer and ReduceLROnPlateau lr-scheduler with settings from config.
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=CFG.lr, weight_decay=CFG.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=CFG.patience, factor=CFG.factor
    )
    # If step == "batch", lr_scheduler steps for each batch of the loader;
    # otherwise (this case) it only steps after finishing each epoch.
    step = "epoch"
    loss_history = {"train": [], "valid": []}
    # Track the best (lowest) validation loss seen so far.
    best_loss = float('inf')
    # BUGFIX: loop over the `epochs` argument, not CFG.epochs — previously the
    # parameter was accepted but silently ignored.
    for epoch in range(epochs):
        print(f"Epoch: {epoch + 1}")
        # Train for one epoch.
        model.train()
        train_loss = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        loss_history["train"].append(train_loss.avg)
        # Validate the trained model.
        model.eval()
        with torch.no_grad():
            valid_loss = valid_epoch(model, valid_loader)
        loss_history["valid"].append(valid_loss.avg)
        # Save this model whenever its avg validation loss beats the best so far.
        if valid_loss.avg < best_loss:
            best_loss = valid_loss.avg
            model.save_current()
            print("Saved Best Model!")
        # ReduceLROnPlateau steps on the monitored metric once per epoch.
        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)
    return model, loss_history