# COVER / train_one_dataset.py
# nanushio
# + [MAJOR] [ROOT] [CREATE] 1. fork repo from COVER github
# feb2918
import torch
import cv2
import random
import os.path as osp
import argparse
from scipy.stats import spearmanr, pearsonr
from scipy.stats.stats import kendalltau as kendallr
import numpy as np
from time import time
from tqdm import tqdm
import pickle
import math
import wandb
import yaml
from collections import OrderedDict
from functools import reduce
from thop import profile
import copy
import cover.models as models
import cover.datasets as datasets
def train_test_split(dataset_path, ann_file, ratio=0.8, seed=42):
    """Randomly split an annotation file into train and test video lists.

    Each non-blank line of ``ann_file`` must contain exactly four
    comma-separated fields. The first is a filename relative to
    ``dataset_path`` and the last is a numeric label; the middle two are
    ignored.

    The global ``random`` module is re-seeded with ``seed`` before shuffling.
    The split is therefore reproducible, and later users of ``random`` see
    the seeded state.

    Args:
        dataset_path: directory prepended to every filename.
        ann_file: path of the annotation file to read.
        ratio: fraction of (shuffled) entries placed in the train list.
        seed: seed passed to ``random.seed``.

    Returns:
        ``(train_infos, test_infos)``: lists of
        ``dict(filename=..., label=...)``. The first holds the first
        ``int(ratio * N)`` shuffled entries; the second holds the rest.
    """
    random.seed(seed)
    print(seed)
    video_infos = []
    with open(ann_file, "r") as fin:
        for raw_line in fin:
            line = raw_line.strip()
            if not line:
                # Blank lines (e.g. an extra empty line at the end of the
                # file) would otherwise fail the 4-field unpacking below.
                continue
            filename, _, _, label = line.split(",")
            video_infos.append(
                dict(filename=osp.join(dataset_path, filename), label=float(label))
            )
    random.shuffle(video_infos)
    cut = int(ratio * len(video_infos))
    return video_infos[:cut], video_infos[cut:]
def rank_loss(y_pred, y):
    """Pairwise hinge ranking loss between predictions and labels.

    Every pair ``(i, j)`` whose predicted order disagrees with the label
    order contributes ``relu((y_pred[i] - y_pred[j]) * sign(y[j] - y[i]))``.
    The sum is averaged over the ``n * (n - 1)`` ordered pairs and divided
    by ``1 + max`` violation.

    Args:
        y_pred: predicted scores. The pairwise matrix is built by broadcasting
            against ``y_pred.t()``, so a column vector ``(n, 1)`` is
            presumably expected (the training loop passes labels as
            ``(n, 1)``). TODO confirm for other callers.
        y: ground-truth labels with the same shape as ``y_pred``.

    Returns:
        A scalar float32 tensor. It is 0 when fewer than two samples are
        given, because there are no pairs to rank.
    """
    n = y_pred.shape[0]
    if n < 2:
        # The general formula divides by (n - 1) == 0 and yields NaN, which
        # would poison the gradients of e.g. a size-1 last batch.
        # Multiplying by 0 keeps the result attached to the graph.
        return (y_pred.sum() * 0).float()
    violations = torch.nn.functional.relu(
        (y_pred - y_pred.t()) * torch.sign(y.t() - y)
    )
    scale = 1 + torch.max(violations)
    return (torch.sum(violations) / n / (n - 1) / scale).float()
def gaussian(y, eps=1e-8):
    """Standardise ``y`` to zero mean and unit standard deviation.

    Uses ``Tensor.std()`` with its default (unbiased, ``n - 1``) estimator.

    Args:
        y: tensor to standardise over all of its elements.
        eps: added to the standard deviation to avoid division by zero.

    Returns:
        ``(y - y.mean()) / (y.std() + eps)``.
    """
    return (y - y.mean()) / (y.std() + eps)
def plcc_loss(y_pred, y):
    """PLCC-style regression loss between predictions and labels.

    Both tensors are standardised with population statistics
    (``unbiased=False``), with ``1e-8`` added to each standard deviation.
    The loss is the mean of two terms, each divided by 4:

    * the MSE between the standardised predictions and labels;
    * the MSE between the standardised labels and the standardised
      predictions scaled by their correlation ``rho``.

    Returns:
        A scalar float32 tensor.
    """

    def _standardise(t):
        std, mean = torch.std_mean(t, unbiased=False)
        return (t - mean) / (std + 1e-8)

    pred_z = _standardise(y_pred)
    label_z = _standardise(y)
    direct_term = torch.nn.functional.mse_loss(pred_z, label_z) / 4
    rho = torch.mean(pred_z * label_z)
    scaled_term = torch.nn.functional.mse_loss(rho * pred_z, label_z) / 4
    return ((direct_term + scaled_term) / 2).float()
def rescaled_l2_loss(y_pred, y, eps=1e-8):
    """MSE between standardised predictions and standardised labels.

    Each tensor is centred on its mean and divided by its (unbiased)
    standard deviation plus ``eps``.

    Args:
        y_pred: predicted scores.
        y: ground-truth labels, same shape as ``y_pred``.
        eps: stabiliser added to both standard deviations. Previously the
            body referenced an undefined name ``eps``, so every call raised
            NameError.

    Returns:
        A scalar tensor with the mean squared error.
    """
    y_pred_rs = (y_pred - y_pred.mean()) / (y_pred.std() + eps)
    y_rs = (y - y.mean()) / (y.std() + eps)
    return torch.nn.functional.mse_loss(y_pred_rs, y_rs)
def rplcc_loss(y_pred, y, eps=1e-8):
    """Return ``(1 - PLCC) / 2`` between ``y_pred`` and ``y``.

    Both tensors are standardised with ``gaussian``, which uses the unbiased
    standard deviation (``n - 1`` denominator). Their Pearson correlation is
    therefore ``sum(z_pred * z) / (n - 1)``, where ``n`` is the number of
    elements.

    Args:
        y_pred: predicted scores.
        y: ground-truth labels with the same shape.
        eps: forwarded to ``gaussian`` as the standard-deviation stabiliser.

    Returns:
        A scalar tensor in ``[0, 1]``: 0 for perfect positive correlation,
        1 for perfect negative correlation.
    """
    y_pred, y = gaussian(y_pred, eps), gaussian(y, eps)
    # Divide by (n - 1) to match the unbiased std used by ``gaussian``.
    # Dividing by n would shrink the correlation by a factor (n - 1) / n.
    plcc = torch.sum(y_pred * y) / (y_pred.numel() - 1)
    return (1 - plcc) / 2
def self_similarity_loss(f, f_hat, f_hat_detach=False):
    """One minus the mean row-wise cosine similarity between ``f`` and ``f_hat``.

    Args:
        f: feature tensor; cosine similarity is taken along ``dim=1``.
        f_hat: tensor compared against ``f``.
        f_hat_detach: when true, ``f_hat`` is detached so no gradient flows
            into it.

    Returns:
        A scalar tensor; 0 when every row of ``f`` points in the same
        direction as the matching row of ``f_hat``.
    """
    target = f_hat.detach() if f_hat_detach else f_hat
    row_cosines = torch.nn.functional.cosine_similarity(f, target, dim=1)
    return 1 - row_cosines.mean()
def contrastive_similarity_loss(f, f_hat, f_hat_detach=False, eps=1e-8):
    """Ratio of intra-row cosine distance to cross-column cosine distance.

    The numerator is ``1 - mean cosine similarity`` along ``dim=1``; the
    denominator is ``1 - mean cosine similarity`` along ``dim=0``, plus
    ``eps``.

    Args:
        f: feature tensor.
        f_hat: tensor compared against ``f``.
        f_hat_detach: when true, ``f_hat`` is detached so no gradient flows
            into it.
        eps: keeps the denominator away from zero.

    Returns:
        A scalar tensor.
    """
    target = f_hat.detach() if f_hat_detach else f_hat
    cosine = torch.nn.functional.cosine_similarity
    intra_distance = 1 - cosine(f, target, dim=1).mean()
    cross_distance = 1 - cosine(f, target, dim=0).mean()
    return intra_distance / (cross_distance + eps)
def rescale(pr, gt=None):
    """Standardise ``pr`` and optionally map it onto the scale of ``gt``.

    Args:
        pr: sequence or array of predicted scores.
        gt: optional sequence or array of ground-truth scores. When given,
            the standardised predictions are mapped to ``gt``'s mean and
            (population) standard deviation.

    Returns:
        A float ndarray with the same length as ``pr``. If ``pr`` is
        constant (zero spread), the standardised values are all 0 instead
        of NaN.
    """
    # Convert explicitly rather than relying on numpy scalars' __rsub__
    # to turn a Python list into an array.
    pr = np.asarray(pr, dtype=float)
    centred = pr - np.mean(pr)
    spread = np.std(pr)
    standardised = centred / spread if spread > 0 else centred
    if gt is None:
        return standardised
    gt = np.asarray(gt, dtype=float)
    return standardised * np.std(gt) + np.mean(gt)
# Names of the sampled views a batch may carry. Only the names actually
# present in a batch are moved to the device and passed to the model.
sample_types = [
    "semantic",
    "technical",
    "aesthetic",
]
def _update_ema(model, model_ema, decay=0.999):
    """Blend ``model``'s parameters into ``model_ema`` in place.

    Each EMA parameter becomes ``decay * ema + (1 - decay) * param``. Only
    parameters are averaged; buffers are left untouched.
    """
    ema_params = dict(model_ema.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)


def _supervision_loss(scores, y_pred, y, need_separate_sup, rank_weight=0.3):
    """Return ``(loss, plcc_logs)`` for one batch.

    With ``need_separate_sup``, every branch in ``scores`` (averaged over its
    last three axes) is supervised on its own with
    ``plcc_loss + rank_weight * rank_loss``. The per-branch PLCC losses are
    returned as wandb-ready scalars keyed ``train/plcc_loss_a``,
    ``..._b``, and so on. Otherwise the same loss is applied to the fused
    prediction ``y_pred`` and ``plcc_logs`` is empty.
    """
    if not need_separate_sup:
        loss = plcc_loss(y_pred, y) + rank_weight * rank_loss(y_pred, y)
        return loss, {}
    plcc_terms, rank_terms, plcc_logs = [], [], {}
    for idx, branch_scores in enumerate(scores):
        branch_pred = branch_scores.mean((-3, -2, -1))
        p_loss = plcc_loss(branch_pred, y)
        plcc_terms.append(p_loss)
        rank_terms.append(rank_loss(branch_pred, y))
        plcc_logs[f"train/plcc_loss_{chr(ord('a') + idx)}"] = p_loss.item()
    loss = sum(plcc_terms) + sum(rank_weight * r for r in rank_terms)
    return loss, plcc_logs


def finetune_epoch(
    ft_loader,
    model,
    model_ema,
    optimizer,
    scheduler,
    device,
    epoch=-1,
    need_upsampled=False,
    need_feat=False,
    need_fused=False,
    need_separate_sup=True,
):
    """Train ``model`` for one pass over ``ft_loader``.

    For each batch, the views listed in ``sample_types`` that are present in
    the batch are moved to ``device``. The model is run with
    ``reduce_scores=False`` and the loss is back-propagated. The optimizer
    and the scheduler are stepped once per batch. When ``model_ema`` is
    given, it is updated after every step. Losses are logged to wandb. The
    model is left in ``eval()`` mode on return.

    Args:
        ft_loader: iterable of batch dicts with at least ``"gt_label"``.
        model: model returning a sequence of per-branch score tensors.
        model_ema: optional EMA copy of ``model`` (or ``None``).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per batch.
        device: device the inputs are moved to.
        epoch: only used in the progress-bar description.
        need_upsampled, need_feat, need_fused: accepted for interface
            compatibility; not used by this function.
        need_separate_sup: supervise every branch separately (default).
            If false, supervise the summed prediction instead.
    """
    model.train()
    for data in tqdm(ft_loader, desc=f"Training in epoch {epoch}"):
        optimizer.zero_grad()
        video = {key: data[key].to(device) for key in sample_types if key in data}
        y = data["gt_label"].float().detach().to(device).unsqueeze(-1)

        scores = model(video, inference=False, reduce_scores=False)
        # Fused prediction: sum of all branch scores, averaged over the
        # last three axes.
        y_pred = reduce(lambda a, b: a + b, scores).mean((-3, -2, -1))

        loss, plcc_logs = _supervision_loss(scores, y_pred, y, need_separate_sup)
        if plcc_logs:
            wandb.log(plcc_logs)
        wandb.log({"train/total_loss": loss.item()})

        loss.backward()
        optimizer.step()
        scheduler.step()
        if model_ema is not None:
            _update_ema(model, model_ema)
    model.eval()
def profile_inference(inf_set, model, device):
    """Print the FLOPs and parameter count of ``model`` on one sample.

    Takes ``inf_set[0]``. Each view listed in ``sample_types`` that the
    sample contains is moved to ``device`` and given a leading batch axis.
    ``thop.profile`` is then run under ``torch.no_grad()``.
    """
    sample = inf_set[0]
    batched = {
        view: sample[view].to(device).unsqueeze(0)
        for view in sample_types
        if view in sample
    }
    with torch.no_grad():
        flops, params = profile(model, (batched,))
    print(
        f"The FLOps of the Variant is {flops/1e9:.1f}G, with Params {params/1e6:.2f}M."
    )
def _reshape_into_clips(x, num_clips):
    """Fold ``num_clips`` temporal clips of a ``(b, c, t, h, w)`` tensor into the batch axis.

    Returns a tensor of shape ``(b * num_clips, c, t // num_clips, h, w)``.
    """
    b, c, t, h, w = x.shape
    return (
        x.reshape(b, c, num_clips, t // num_clips, h, w)
        .permute(0, 2, 1, 3, 4, 5)
        .reshape(b * num_clips, c, t // num_clips, h, w)
    )


def _num_clips_for(data, key):
    """Return the clip count for view ``key``.

    ``data["num_clips"]`` may be a per-view dict or a single value.
    """
    num_clips = data["num_clips"]
    return num_clips[key] if isinstance(num_clips, dict) else num_clips


def _save_checkpoint(model, save_name, suffix, save_type, validation_results):
    """Write ``model``'s weights to ``pretrained_weights/{save_name}_{suffix}_finetuned.pth``.

    With ``save_type == "head"``, every key containing ``"backbone"`` is
    dropped. The directory is created if it does not exist.
    """
    import os

    state_dict = model.state_dict()
    if save_type == "head":
        state_dict = OrderedDict(
            (key, value) for key, value in state_dict.items() if "backbone" not in key
        )
        print("Following keys are saved :", state_dict.keys())
    os.makedirs("pretrained_weights", exist_ok=True)
    torch.save(
        {"state_dict": state_dict, "validation_results": validation_results},
        f"pretrained_weights/{save_name}_{suffix}_finetuned.pth",
    )


def inference_set(
    inf_loader,
    model,
    device,
    best_,
    save_model=False,
    suffix="s",
    save_name="divide",
    save_type="head",
):
    """Evaluate ``model`` on ``inf_loader`` and track the best metrics.

    Each batch's views are split into clips and scored under ``no_grad``.
    The per-video mean prediction is rescaled onto the label scale, then
    SRCC, PLCC, KRCC and RMSE against ``gt_label`` are logged to wandb and
    printed. When ``save_model`` is true and ``SRCC + PLCC`` beats the
    previous best, a checkpoint is written. Its ``validation_results`` are
    the metrics of the saved weights.

    Args:
        inf_loader: loader of batch dicts with ``"gt_label"`` and
            ``"num_clips"``. ``gt_label.item()`` is used, so presumably
            batch size 1 (TODO confirm).
        model: model to evaluate.
        device: device the inputs are moved to.
        best_: previous best ``(srcc, plcc, krcc, rmse)``.
        save_model: whether an improved model may be checkpointed.
        suffix: tag used in wandb keys and the checkpoint filename.
        save_name: checkpoint filename prefix.
        save_type: ``"head"`` saves non-backbone weights only; any other
            value saves the full state dict.

    Returns:
        Updated best ``(srcc, plcc, krcc, rmse)``. SRCC/PLCC/KRCC are maxima;
        RMSE is a minimum.
    """
    results = []
    best_s, best_p, best_k, best_r = best_
    for data in tqdm(inf_loader, desc="Validating"):
        result = dict()
        video, video_up = {}, {}
        for key in sample_types:
            if key in data:
                video[key] = _reshape_into_clips(
                    data[key].to(device), _num_clips_for(data, key)
                )
            if key + "_up" in data:
                # Previously divided by the whole ``num_clips`` dict here.
                video_up[key] = _reshape_into_clips(
                    data[key + "_up"].to(device), _num_clips_for(data, key)
                )
        with torch.no_grad():
            result["pr_labels"] = model(video, reduce_scores=True).cpu().numpy()
            if video_up:
                result["pr_labels_up"] = model(video_up).cpu().numpy()
        result["gt_label"] = data["gt_label"].item()
        del video, video_up
        results.append(result)

    gt_labels = np.asarray([r["gt_label"] for r in results])
    pr_labels = rescale([np.mean(r["pr_labels"]) for r in results], gt_labels)
    s = spearmanr(gt_labels, pr_labels)[0]
    p = pearsonr(gt_labels, pr_labels)[0]
    k = kendallr(gt_labels, pr_labels)[0]
    r = np.sqrt(np.mean((gt_labels - pr_labels) ** 2))
    wandb.log(
        {
            f"val_{suffix}/SRCC-{suffix}": s,
            f"val_{suffix}/PLCC-{suffix}": p,
            f"val_{suffix}/KRCC-{suffix}": k,
            f"val_{suffix}/RMSE-{suffix}": r,
        }
    )
    del results
    torch.cuda.empty_cache()

    if save_model and s + p > best_s + best_p:
        # Store the metrics these weights achieved, not the stale previous best.
        _save_checkpoint(model, save_name, suffix, save_type, (s, p, k, r))

    best_s, best_p, best_k, best_r = (
        max(best_s, s),
        max(best_p, p),
        max(best_k, k),
        min(best_r, r),
    )
    wandb.log(
        {
            f"val_{suffix}/best_SRCC-{suffix}": best_s,
            f"val_{suffix}/best_PLCC-{suffix}": best_p,
            f"val_{suffix}/best_KRCC-{suffix}": best_k,
            f"val_{suffix}/best_RMSE-{suffix}": best_r,
        }
    )
    print(
        f"For {len(inf_loader)} videos, \nthe accuracy of the model: [{suffix}] is as follows:\n SROCC: {s:.4f} best: {best_s:.4f} \n PLCC: {p:.4f} best: {best_p:.4f} \n KROCC: {k:.4f} best: {best_k:.4f} \n RMSE: {r:.4f} best: {best_r:.4f}."
    )
    return best_s, best_p, best_k, best_r
def _parse_args():
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-o", "--opt", type=str, default="cover.yml", help="the option file"
    )
    parser.add_argument(
        "-t", "--target_set", type=str, default="val-kv1k", help="target_set"
    )
    parser.add_argument(
        "-n", "--name", type=str, default="COVER_TMP",
        help="model name to save checkpoint",
    )
    parser.add_argument(
        "-uh", "--usehead", type=int, default=0,
        help="wheather to load header weight from checkpoint",
    )
    return parser.parse_args()


def _prepare_random_split(opt, target_set, split):
    """Point ``opt["data"]["train"]`` / ``["eval"]`` at a random split of ``target_set``.

    Both entries start as deep copies of the target set's config. Their
    ``anno_file`` is replaced by the lists returned by ``train_test_split``,
    seeded with ``split_seed * (split + 1)``. The train set is also limited
    to a single technical clip.
    """
    target = opt["data"][target_set]
    opt["data"]["train"] = copy.deepcopy(target)
    opt["data"]["eval"] = copy.deepcopy(target)
    train_infos, eval_infos = train_test_split(
        target["args"]["data_prefix"],
        target["args"]["anno_file"],
        seed=opt["split_seed"] * (split + 1),
    )
    opt["data"]["train"]["args"]["anno_file"] = train_infos
    opt["data"]["eval"]["args"]["anno_file"] = eval_infos
    opt["data"]["train"]["args"]["sample_types"]["technical"]["num_clips"] = 1


def _build_datasets(opt, prefix):
    """Instantiate every dataset in ``opt["data"]`` whose key starts with ``prefix``."""
    built = {}
    for key, cfg in opt["data"].items():
        if key.startswith(prefix):
            built[key] = getattr(datasets, cfg["type"])(cfg["args"])
            print(len(built[key].video_infos))
    return built


def _build_loaders(opt):
    """Return ``(train_loaders, val_loaders)`` keyed like ``opt["data"]``."""
    train_loaders = {
        key: torch.utils.data.DataLoader(
            dataset,
            batch_size=opt["batch_size"],
            num_workers=opt["num_workers"],
            shuffle=True,
        )
        for key, dataset in _build_datasets(opt, "train").items()
    }
    val_loaders = {
        key: torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            num_workers=opt["num_workers"],
            pin_memory=True,
        )
        for key, dataset in _build_datasets(opt, "eval").items()
    }
    return train_loaders, val_loaders


def _load_weights(model, opt, device, use_head):
    """Load ``opt["test_load_path"]`` into ``model`` (non-strict).

    When ``use_head`` is set, keys from the ``state_dict`` entry of
    ``opt["test_load_header_path"]`` override those of the main checkpoint.
    """
    state_dict = torch.load(opt["test_load_path"], map_location=device)
    # Load fine_tuned header from checkpoint
    if use_head:
        head_checkpoint = torch.load(opt["test_load_header_path"], map_location=device)
        state_dict.update(head_checkpoint["state_dict"])
    # strict=False: allow the head weights to be missing from the checkpoint.
    model.load_state_dict(state_dict, strict=False)


def _build_optimizer(model, opt):
    """Build AdamW with one param group per top-level child.

    Children whose name contains ``"backbone"`` use
    ``lr * backbone_lr_mult``; all others use ``lr``.
    """
    base_lr = opt["optimizer"]["lr"]
    param_groups = []
    for name, child in model.named_children():
        lr = base_lr * opt["optimizer"]["backbone_lr_mult"] if "backbone" in name else base_lr
        param_groups.append({"params": child.parameters(), "lr": lr})
    return torch.optim.AdamW(
        lr=base_lr,
        params=param_groups,
        weight_decay=opt["optimizer"]["wd"],
    )


def _build_scheduler(optimizer, opt, train_loaders):
    """Build a linear warm-up then cosine-decay LR schedule, stepped once per training batch."""
    loaders = list(train_loaders.values())
    warmup_iter = sum(int(opt["warmup_epochs"] * len(loader)) for loader in loaders)
    total_epochs = opt["num_epochs"] + opt["l_num_epochs"]
    # Count iterations over every train loader (not just the last one) and
    # never let the cosine denominator reach zero.
    max_iter = max(sum(int(total_epochs * len(loader)) for loader in loaders), 1)

    def lr_lambda(cur_iter):
        # ``<`` (not ``<=``) avoids 0/0 when there is no warm-up. At
        # cur_iter == warmup_iter both branches give 1.0 anyway.
        if cur_iter < warmup_iter:
            return cur_iter / warmup_iter
        return 0.5 * (1 + math.cos(math.pi * (cur_iter - warmup_iter) / max_iter))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)


def _set_backbone_trainable(model, trainable):
    """Set ``requires_grad`` on every parameter of children whose name contains ``"backbone"``."""
    for name, child in model.named_children():
        if "backbone" in name:
            for param in child.parameters():
                param.requires_grad = trainable


def _run_phase(
    num_epochs,
    label,
    save_type,
    train_loaders,
    val_loaders,
    model,
    model_ema,
    optimizer,
    scheduler,
    device,
    opt,
    save_name,
    bests,
    bests_n,
):
    """Train for ``num_epochs`` epochs and validate after each one.

    ``bests`` is updated in place with the best metrics of the EMA model
    (or the plain model without EMA), using the ``"_s"`` suffix.
    ``bests_n`` is updated with those of the plain model (``"_n"``).
    """
    eval_model = model_ema if model_ema is not None else model
    for epoch in range(num_epochs):
        print(f"{label} Epoch {epoch}:")
        for train_loader in train_loaders.values():
            finetune_epoch(
                train_loader,
                model,
                model_ema,
                optimizer,
                scheduler,
                device,
                epoch,
                opt.get("need_upsampled", False),
                opt.get("need_feat", False),
                opt.get("need_fused", False),
            )
        for key, val_loader in val_loaders.items():
            bests[key] = inference_set(
                val_loader,
                eval_model,
                device,
                bests[key],
                save_model=opt["save_model"],
                save_name=save_name,
                suffix=key + "_s",
                save_type=save_type,
            )
            if model_ema is not None:
                bests_n[key] = inference_set(
                    val_loader,
                    model,
                    device,
                    bests_n[key],
                    save_model=opt["save_model"],
                    save_name=save_name,
                    suffix=key + "_n",
                    save_type=save_type,
                )
            else:
                bests_n[key] = bests[key]


def _print_phase_summary(phase, val_loaders, bests, bests_n):
    """Print the best validation metrics of the ``-s`` and ``-n`` models for every validation set."""
    for key, val_loader in val_loaders.items():
        for tag, best in (("s", bests[key]), ("n", bests_n[key])):
            print(
                f"""For the {phase} transfer process on {key} with {len(val_loader)} videos,
the best validation accuracy of the model-{tag} is as follows:
SROCC: {best[0]:.4f}
PLCC: {best[1]:.4f}
KROCC: {best[2]:.4f}
RMSE: {best[3]:.4f}."""
            )


def main():
    """Fine-tune a model on one dataset, optionally over 10 random splits.

    Reads the YAML option file given by ``--opt``. When ``split_seed`` is
    positive, ``--target_set`` is re-split into train/eval 10 times.
    Otherwise a single run uses the configured train/eval sets. Each run
    first trains with the backbone frozen for ``l_num_epochs`` (head-only
    checkpoints), then end-to-end for ``num_epochs`` (full checkpoints).
    Progress is logged to wandb.
    """
    args = _parse_args()
    with open(args.opt, "r") as f:
        opt = yaml.safe_load(f)
    print(opt)

    ## adaptively choose the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    split_seed = opt.get("split_seed", -1)
    num_splits = 10 if split_seed > 0 else 1
    print(split_seed)

    for split in range(num_splits):
        model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device)
        if split_seed > 0:
            _prepare_random_split(opt, args.target_set, split)
        train_loaders, val_loaders = _build_loaders(opt)

        run = wandb.init(
            project=opt["wandb"]["project_name"],
            name=opt["name"] + f"_target_{args.target_set}_split_{split}"
            if num_splits > 1
            else opt["name"],
            reinit=True,
            settings=wandb.Settings(start_method="thread"),
        )

        _load_weights(model, opt, device, args.usehead)
        model_ema = copy.deepcopy(model) if opt["ema"] else None

        optimizer = _build_optimizer(model, opt)
        scheduler = _build_scheduler(optimizer, opt, train_loaders)

        bests = {key: (-1, -1, -1, 1000) for key in val_loaders}
        bests_n = dict(bests)
        phase_context = dict(
            train_loaders=train_loaders,
            val_loaders=val_loaders,
            model=model,
            model_ema=model_ema,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            opt=opt,
            save_name=args.name + "_head_" + args.target_set + f"_{split}",
            bests=bests,
            bests_n=bests_n,
        )

        # Linear probe: train only the non-backbone modules.
        _set_backbone_trainable(model, False)
        _run_phase(opt["l_num_epochs"], "Linear", "head", **phase_context)
        if opt["l_num_epochs"] > 0:
            _print_phase_summary("linear", val_loaders, bests, bests_n)

        # End-to-end: unfreeze the backbone and train everything.
        _set_backbone_trainable(model, True)
        _run_phase(opt["num_epochs"], "End-to-end", "full", **phase_context)
        if opt["num_epochs"] > 0:
            _print_phase_summary("end-to-end", val_loaders, bests, bests_n)

        run.finish()


if __name__ == "__main__":
    main()