|
import torch |
|
|
|
import cv2 |
|
import random |
|
import os.path as osp |
|
import argparse |
|
from scipy.stats import spearmanr, pearsonr |
|
from scipy.stats.stats import kendalltau as kendallr |
|
import numpy as np |
|
from time import time |
|
from tqdm import tqdm |
|
import pickle |
|
import math |
|
import wandb |
|
import yaml |
|
from collections import OrderedDict |
|
|
|
from functools import reduce |
|
from thop import profile |
|
import copy |
|
|
|
|
|
import cover.models as models |
|
import cover.datasets as datasets |
|
|
|
|
|
def train_test_split(dataset_path, ann_file, ratio=0.8, seed=42): |
|
random.seed(seed) |
|
print(seed) |
|
video_infos = [] |
|
with open(ann_file, "r") as fin: |
|
for line in fin.readlines(): |
|
line_split = line.strip().split(",") |
|
filename, _, _, label = line_split |
|
label = float(label) |
|
filename = osp.join(dataset_path, filename) |
|
video_infos.append(dict(filename=filename, label=label)) |
|
random.shuffle(video_infos) |
|
return ( |
|
video_infos[: int(ratio * len(video_infos))], |
|
video_infos[int(ratio * len(video_infos)) :], |
|
) |
|
|
|
|
|
def rank_loss(y_pred, y): |
|
ranking_loss = torch.nn.functional.relu( |
|
(y_pred - y_pred.t()) * torch.sign((y.t() - y)) |
|
) |
|
scale = 1 + torch.max(ranking_loss) |
|
return ( |
|
torch.sum(ranking_loss) / y_pred.shape[0] / (y_pred.shape[0] - 1) / scale |
|
).float() |
|
|
|
|
|
def gaussian(y, eps=1e-8): |
|
return (y - y.mean()) / (y.std() + 1e-8) |
|
|
|
|
|
def plcc_loss(y_pred, y): |
|
sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False) |
|
y_pred = (y_pred - m_hat) / (sigma_hat + 1e-8) |
|
sigma, m = torch.std_mean(y, unbiased=False) |
|
y = (y - m) / (sigma + 1e-8) |
|
loss0 = torch.nn.functional.mse_loss(y_pred, y) / 4 |
|
rho = torch.mean(y_pred * y) |
|
loss1 = torch.nn.functional.mse_loss(rho * y_pred, y) / 4 |
|
return ((loss0 + loss1) / 2).float() |
|
|
|
|
|
def rescaled_l2_loss(y_pred, y): |
|
y_pred_rs = (y_pred - y_pred.mean()) / y_pred.std() |
|
y_rs = (y - y.mean()) / (y.std() + eps) |
|
return torch.nn.functional.mse_loss(y_pred_rs, y_rs) |
|
|
|
|
|
def rplcc_loss(y_pred, y, eps=1e-8): |
|
|
|
y_pred, y = gaussian(y_pred), gaussian(y) |
|
cov = torch.sum(y_pred * y) / y_pred.shape[0] |
|
|
|
return (1 - cov) / 2 |
|
|
|
|
|
def self_similarity_loss(f, f_hat, f_hat_detach=False): |
|
if f_hat_detach: |
|
f_hat = f_hat.detach() |
|
return 1 - torch.nn.functional.cosine_similarity(f, f_hat, dim=1).mean() |
|
|
|
|
|
def contrastive_similarity_loss(f, f_hat, f_hat_detach=False, eps=1e-8): |
|
if f_hat_detach: |
|
f_hat = f_hat.detach() |
|
intra_similarity = torch.nn.functional.cosine_similarity(f, f_hat, dim=1).mean() |
|
cross_similarity = torch.nn.functional.cosine_similarity(f, f_hat, dim=0).mean() |
|
return (1 - intra_similarity) / (1 - cross_similarity + eps) |
|
|
|
|
|
def rescale(pr, gt=None): |
|
if gt is None: |
|
pr = (pr - np.mean(pr)) / np.std(pr) |
|
else: |
|
pr = ((pr - np.mean(pr)) / np.std(pr)) * np.std(gt) + np.mean(gt) |
|
return pr |
|
|
|
sample_types = ["semantic", "technical", "aesthetic"] |
|
|
|
|
|
def finetune_epoch( |
|
ft_loader, |
|
model, |
|
model_ema, |
|
optimizer, |
|
scheduler, |
|
device, |
|
epoch=-1, |
|
need_upsampled=False, |
|
need_feat=False, |
|
need_fused=False, |
|
need_separate_sup=True, |
|
): |
|
model.train() |
|
for i, data in enumerate(tqdm(ft_loader, desc=f"Training in epoch {epoch}")): |
|
optimizer.zero_grad() |
|
video = {} |
|
for key in sample_types: |
|
if key in data: |
|
video[key] = data[key].to(device) |
|
|
|
y = data["gt_label"].float().detach().to(device).unsqueeze(-1) |
|
|
|
scores = model(video, inference=False, reduce_scores=False) |
|
if len(scores) > 1: |
|
y_pred = reduce(lambda x, y: x + y, scores) |
|
else: |
|
y_pred = scores[0] |
|
y_pred = y_pred.mean((-3, -2, -1)) |
|
|
|
frame_inds = data["frame_inds"] |
|
|
|
|
|
loss = 0 |
|
|
|
if need_separate_sup: |
|
p_loss_a = plcc_loss(scores[0].mean((-3, -2, -1)), y) |
|
p_loss_b = plcc_loss(scores[1].mean((-3, -2, -1)), y) |
|
p_loss_c = plcc_loss(scores[2].mean((-3, -2, -1)), y) |
|
r_loss_a = rank_loss(scores[0].mean((-3, -2, -1)), y) |
|
r_loss_b = rank_loss(scores[1].mean((-3, -2, -1)), y) |
|
r_loss_c = rank_loss(scores[2].mean((-3, -2, -1)), y) |
|
loss += ( |
|
p_loss_a + p_loss_b + p_loss_c + 0.3 * r_loss_a + 0.3 * r_loss_b + 0.3 * r_loss_c |
|
) |
|
wandb.log( |
|
{ |
|
"train/plcc_loss_a": p_loss_a.item(), |
|
"train/plcc_loss_b": p_loss_b.item(), |
|
"train/plcc_loss_c": p_loss_c.item(), |
|
} |
|
) |
|
|
|
wandb.log( |
|
{"train/total_loss": loss.item(),} |
|
) |
|
|
|
loss.backward() |
|
optimizer.step() |
|
scheduler.step() |
|
|
|
|
|
|
|
if model_ema is not None: |
|
model_params = dict(model.named_parameters()) |
|
model_ema_params = dict(model_ema.named_parameters()) |
|
for k in model_params.keys(): |
|
model_ema_params[k].data.mul_(0.999).add_( |
|
model_params[k].data, alpha=1 - 0.999 |
|
) |
|
model.eval() |
|
|
|
|
|
def profile_inference(inf_set, model, device): |
|
video = {} |
|
data = inf_set[0] |
|
for key in sample_types: |
|
if key in data: |
|
video[key] = data[key].to(device).unsqueeze(0) |
|
with torch.no_grad(): |
|
|
|
flops, params = profile(model, (video,)) |
|
print( |
|
f"The FLOps of the Variant is {flops/1e9:.1f}G, with Params {params/1e6:.2f}M." |
|
) |
|
|
|
|
|
def inference_set( |
|
inf_loader, |
|
model, |
|
device, |
|
best_, |
|
save_model=False, |
|
suffix="s", |
|
save_name="divide", |
|
save_type="head", |
|
): |
|
|
|
results = [] |
|
|
|
best_s, best_p, best_k, best_r = best_ |
|
|
|
for i, data in enumerate(tqdm(inf_loader, desc="Validating")): |
|
result = dict() |
|
video, video_up = {}, {} |
|
for key in sample_types: |
|
if key in data: |
|
video[key] = data[key].to(device) |
|
|
|
b, c, t, h, w = video[key].shape |
|
video[key] = ( |
|
video[key] |
|
.reshape( |
|
b, c, data["num_clips"][key], t // data["num_clips"][key], h, w |
|
) |
|
.permute(0, 2, 1, 3, 4, 5) |
|
.reshape( |
|
b * data["num_clips"][key], c, t // data["num_clips"][key], h, w |
|
) |
|
) |
|
if key + "_up" in data: |
|
video_up[key] = data[key + "_up"].to(device) |
|
|
|
b, c, t, h, w = video_up[key].shape |
|
video_up[key] = ( |
|
video_up[key] |
|
.reshape(b, c, data["num_clips"], t // data["num_clips"], h, w) |
|
.permute(0, 2, 1, 3, 4, 5) |
|
.reshape(b * data["num_clips"], c, t // data["num_clips"], h, w) |
|
) |
|
|
|
with torch.no_grad(): |
|
result["pr_labels"] = model(video, reduce_scores=True).cpu().numpy() |
|
if len(list(video_up.keys())) > 0: |
|
result["pr_labels_up"] = model(video_up).cpu().numpy() |
|
|
|
result["gt_label"] = data["gt_label"].item() |
|
del video, video_up |
|
results.append(result) |
|
|
|
|
|
gt_labels = [r["gt_label"] for r in results] |
|
pr_labels = [np.mean(r["pr_labels"][:]) for r in results] |
|
pr_labels = rescale(pr_labels, gt_labels) |
|
|
|
s = spearmanr(gt_labels, pr_labels)[0] |
|
p = pearsonr(gt_labels, pr_labels)[0] |
|
k = kendallr(gt_labels, pr_labels)[0] |
|
r = np.sqrt(((gt_labels - pr_labels) ** 2).mean()) |
|
|
|
wandb.log( |
|
{ |
|
f"val_{suffix}/SRCC-{suffix}": s, |
|
f"val_{suffix}/PLCC-{suffix}": p, |
|
f"val_{suffix}/KRCC-{suffix}": k, |
|
f"val_{suffix}/RMSE-{suffix}": r, |
|
} |
|
) |
|
|
|
del results, result |
|
torch.cuda.empty_cache() |
|
|
|
if s + p > best_s + best_p and save_model: |
|
state_dict = model.state_dict() |
|
|
|
if save_type == "head": |
|
head_state_dict = OrderedDict() |
|
for key, v in state_dict.items(): |
|
if "backbone" in key: |
|
continue |
|
else: |
|
head_state_dict[key] = v |
|
print("Following keys are saved :", head_state_dict.keys()) |
|
torch.save( |
|
{"state_dict": head_state_dict, "validation_results": best_,}, |
|
f"pretrained_weights/{save_name}_{suffix}_finetuned.pth", |
|
) |
|
else: |
|
torch.save( |
|
{"state_dict": state_dict, "validation_results": best_,}, |
|
f"pretrained_weights/{save_name}_{suffix}_finetuned.pth", |
|
) |
|
|
|
best_s, best_p, best_k, best_r = ( |
|
max(best_s, s), |
|
max(best_p, p), |
|
max(best_k, k), |
|
min(best_r, r), |
|
) |
|
|
|
wandb.log( |
|
{ |
|
f"val_{suffix}/best_SRCC-{suffix}": best_s, |
|
f"val_{suffix}/best_PLCC-{suffix}": best_p, |
|
f"val_{suffix}/best_KRCC-{suffix}": best_k, |
|
f"val_{suffix}/best_RMSE-{suffix}": best_r, |
|
} |
|
) |
|
|
|
print( |
|
f"For {len(inf_loader)} videos, \nthe accuracy of the model: [{suffix}] is as follows:\n SROCC: {s:.4f} best: {best_s:.4f} \n PLCC: {p:.4f} best: {best_p:.4f} \n KROCC: {k:.4f} best: {best_k:.4f} \n RMSE: {r:.4f} best: {best_r:.4f}." |
|
) |
|
|
|
return best_s, best_p, best_k, best_r |
|
|
|
|
|
|
|
|
|
def main(): |
|
|
|
parser = argparse.ArgumentParser() |
|
parser.add_argument( |
|
"-o", "--opt", type=str, default="cover.yml", help="the option file" |
|
) |
|
|
|
parser.add_argument( |
|
"-t", "--target_set", type=str, default="val-kv1k", help="target_set" |
|
) |
|
|
|
parser.add_argument('-n', "--name", type=str, default="COVER_TMP", help='model name to save checkpoint') |
|
parser.add_argument('-uh', "--usehead", type=int, default=0, help='wheather to load header weight from checkpoint') |
|
|
|
args = parser.parse_args() |
|
with open(args.opt, "r") as f: |
|
opt = yaml.safe_load(f) |
|
print(opt) |
|
|
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
bests_ = [] |
|
|
|
if opt.get("split_seed", -1) > 0: |
|
num_splits = 10 |
|
else: |
|
num_splits = 1 |
|
|
|
print(opt["split_seed"]) |
|
|
|
for split in range(10): |
|
model = getattr(models, opt["model"]["type"])(**opt["model"]["args"]).to(device) |
|
if opt.get("split_seed", -1) > 0: |
|
opt["data"]["train"] = copy.deepcopy(opt["data"][args.target_set]) |
|
opt["data"]["eval"] = copy.deepcopy(opt["data"][args.target_set]) |
|
|
|
split_duo = train_test_split( |
|
opt["data"][args.target_set]["args"]["data_prefix"], |
|
opt["data"][args.target_set]["args"]["anno_file"], |
|
seed=opt["split_seed"] * (split + 1), |
|
) |
|
( |
|
opt["data"]["train"]["args"]["anno_file"], |
|
opt["data"]["eval"]["args"]["anno_file"], |
|
) = split_duo |
|
opt["data"]["train"]["args"]["sample_types"]["technical"]["num_clips"] = 1 |
|
|
|
train_datasets = {} |
|
for key in opt["data"]: |
|
if key.startswith("train"): |
|
train_dataset = getattr(datasets, opt["data"][key]["type"])( |
|
opt["data"][key]["args"] |
|
) |
|
train_datasets[key] = train_dataset |
|
print(len(train_dataset.video_infos)) |
|
|
|
train_loaders = {} |
|
for key, train_dataset in train_datasets.items(): |
|
train_loaders[key] = torch.utils.data.DataLoader( |
|
train_dataset, |
|
batch_size=opt["batch_size"], |
|
num_workers=opt["num_workers"], |
|
shuffle=True, |
|
) |
|
|
|
val_datasets = {} |
|
for key in opt["data"]: |
|
if key.startswith("eval"): |
|
val_dataset = getattr(datasets, opt["data"][key]["type"])( |
|
opt["data"][key]["args"] |
|
) |
|
print(len(val_dataset.video_infos)) |
|
val_datasets[key] = val_dataset |
|
|
|
val_loaders = {} |
|
for key, val_dataset in val_datasets.items(): |
|
val_loaders[key] = torch.utils.data.DataLoader( |
|
val_dataset, |
|
batch_size=1, |
|
num_workers=opt["num_workers"], |
|
pin_memory=True, |
|
) |
|
|
|
run = wandb.init( |
|
project=opt["wandb"]["project_name"], |
|
name=opt["name"] + f"_target_{args.target_set}_split_{split}" |
|
if num_splits > 1 |
|
else opt["name"], |
|
reinit=True, |
|
settings=wandb.Settings(start_method="thread"), |
|
) |
|
|
|
state_dict = torch.load(opt["test_load_path"], map_location=device) |
|
|
|
|
|
if args.usehead: |
|
state_dict_head = torch.load(opt["test_load_header_path"], map_location=device) |
|
for key in state_dict_head['state_dict'].keys(): |
|
state_dict[key] = state_dict_head['state_dict'][key] |
|
|
|
|
|
model.load_state_dict(state_dict, strict=False) |
|
|
|
if opt["ema"]: |
|
from copy import deepcopy |
|
|
|
model_ema = deepcopy(model) |
|
else: |
|
model_ema = None |
|
|
|
|
|
|
|
|
|
|
|
param_groups = [] |
|
|
|
for key, value in dict(model.named_children()).items(): |
|
if "backbone" in key: |
|
param_groups += [ |
|
{ |
|
"params": value.parameters(), |
|
"lr": opt["optimizer"]["lr"] |
|
* opt["optimizer"]["backbone_lr_mult"], |
|
} |
|
] |
|
else: |
|
param_groups += [ |
|
{"params": value.parameters(), "lr": opt["optimizer"]["lr"]} |
|
] |
|
|
|
optimizer = torch.optim.AdamW( |
|
lr=opt["optimizer"]["lr"], |
|
params=param_groups, |
|
weight_decay=opt["optimizer"]["wd"], |
|
) |
|
warmup_iter = 0 |
|
for train_loader in train_loaders.values(): |
|
warmup_iter += int(opt["warmup_epochs"] * len(train_loader)) |
|
max_iter = int((opt["num_epochs"] + opt["l_num_epochs"]) * len(train_loader)) |
|
lr_lambda = ( |
|
lambda cur_iter: cur_iter / warmup_iter |
|
if cur_iter <= warmup_iter |
|
else 0.5 * (1 + math.cos(math.pi * (cur_iter - warmup_iter) / max_iter)) |
|
) |
|
|
|
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda,) |
|
|
|
bests = {} |
|
bests_n = {} |
|
for key in val_loaders: |
|
bests[key] = -1, -1, -1, 1000 |
|
bests_n[key] = -1, -1, -1, 1000 |
|
|
|
for key, value in dict(model.named_children()).items(): |
|
if "backbone" in key: |
|
for param in value.parameters(): |
|
param.requires_grad = False |
|
|
|
for epoch in range(opt["l_num_epochs"]): |
|
print(f"Linear Epoch {epoch}:") |
|
for key, train_loader in train_loaders.items(): |
|
finetune_epoch( |
|
train_loader, |
|
model, |
|
model_ema, |
|
optimizer, |
|
scheduler, |
|
device, |
|
epoch, |
|
opt.get("need_upsampled", False), |
|
opt.get("need_feat", False), |
|
opt.get("need_fused", False), |
|
) |
|
for key in val_loaders: |
|
bests[key] = inference_set( |
|
val_loaders[key], |
|
model_ema if model_ema is not None else model, |
|
device, |
|
bests[key], |
|
save_model=opt["save_model"], |
|
save_name=args.name + "_head_" + args.target_set + f"_{split}", |
|
suffix=key + "_s", |
|
) |
|
if model_ema is not None: |
|
bests_n[key] = inference_set( |
|
val_loaders[key], |
|
model, |
|
device, |
|
bests_n[key], |
|
save_model=opt["save_model"], |
|
save_name=args.name |
|
+ "_head_" |
|
+ args.target_set |
|
+ f"_{split}", |
|
suffix=key + "_n", |
|
) |
|
else: |
|
bests_n[key] = bests[key] |
|
|
|
if opt["l_num_epochs"] >= 0: |
|
for key in val_loaders: |
|
print( |
|
f"""For the linear transfer process on {key} with {len(val_loaders[key])} videos, |
|
the best validation accuracy of the model-s is as follows: |
|
SROCC: {bests[key][0]:.4f} |
|
PLCC: {bests[key][1]:.4f} |
|
KROCC: {bests[key][2]:.4f} |
|
RMSE: {bests[key][3]:.4f}.""" |
|
) |
|
|
|
print( |
|
f"""For the linear transfer process on {key} with {len(val_loaders[key])} videos, |
|
the best validation accuracy of the model-n is as follows: |
|
SROCC: {bests_n[key][0]:.4f} |
|
PLCC: {bests_n[key][1]:.4f} |
|
KROCC: {bests_n[key][2]:.4f} |
|
RMSE: {bests_n[key][3]:.4f}.""" |
|
) |
|
|
|
for key, value in dict(model.named_children()).items(): |
|
if "backbone" in key: |
|
for param in value.parameters(): |
|
param.requires_grad = True |
|
|
|
for epoch in range(opt["num_epochs"]): |
|
print(f"End-to-end Epoch {epoch}:") |
|
for key, train_loader in train_loaders.items(): |
|
finetune_epoch( |
|
train_loader, |
|
model, |
|
model_ema, |
|
optimizer, |
|
scheduler, |
|
device, |
|
epoch, |
|
opt.get("need_upsampled", False), |
|
opt.get("need_feat", False), |
|
opt.get("need_fused", False), |
|
) |
|
for key in val_loaders: |
|
bests[key] = inference_set( |
|
val_loaders[key], |
|
model_ema if model_ema is not None else model, |
|
device, |
|
bests[key], |
|
save_model=opt["save_model"], |
|
save_name=args.name + "_head_" + args.target_set + f"_{split}", |
|
suffix=key + "_s", |
|
save_type="full", |
|
) |
|
if model_ema is not None: |
|
bests_n[key] = inference_set( |
|
val_loaders[key], |
|
model, |
|
device, |
|
bests_n[key], |
|
save_model=opt["save_model"], |
|
save_name=args.name |
|
+ "_head_" |
|
+ args.target_set |
|
+ f"_{split}", |
|
suffix=key + "_n", |
|
save_type="full", |
|
) |
|
else: |
|
bests_n[key] = bests[key] |
|
|
|
if opt["num_epochs"] >= 0: |
|
for key in val_loaders: |
|
print( |
|
f"""For the end-to-end transfer process on {key} with {len(val_loaders[key])} videos, |
|
the best validation accuracy of the model-s is as follows: |
|
SROCC: {bests[key][0]:.4f} |
|
PLCC: {bests[key][1]:.4f} |
|
KROCC: {bests[key][2]:.4f} |
|
RMSE: {bests[key][3]:.4f}.""" |
|
) |
|
|
|
print( |
|
f"""For the end-to-end transfer process on {key} with {len(val_loaders[key])} videos, |
|
the best validation accuracy of the model-n is as follows: |
|
SROCC: {bests_n[key][0]:.4f} |
|
PLCC: {bests_n[key][1]:.4f} |
|
KROCC: {bests_n[key][2]:.4f} |
|
RMSE: {bests_n[key][3]:.4f}.""" |
|
) |
|
|
|
for key, value in dict(model.named_children()).items(): |
|
if "backbone" in key: |
|
for param in value.parameters(): |
|
param.requires_grad = True |
|
|
|
run.finish() |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|