|
import torch |
|
|
|
import argparse |
|
import pickle as pkl |
|
|
|
import decord |
|
from decord import VideoReader |
|
import numpy as np |
|
import yaml |
|
|
|
from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition |
|
from cover.models import COVER |
|
|
|
mean, std = ( |
|
torch.FloatTensor([123.675, 116.28, 103.53]), |
|
torch.FloatTensor([58.395, 57.12, 57.375]), |
|
) |
|
|
|
mean_clip, std_clip = ( |
|
torch.FloatTensor([122.77, 116.75, 104.09]), |
|
torch.FloatTensor([68.50, 66.63, 70.32]) |
|
) |
|
|
|
def fuse_results(results: list): |
|
x = (results[0] + results[1] + results[2]) |
|
return { |
|
"semantic" : results[0], |
|
"technical": results[1], |
|
"aesthetic": results[2], |
|
"overall" : x, |
|
} |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("-o", "--opt" , type=str, default="./cover.yml", help="the option file") |
|
parser.add_argument("--video_path", type=str, default="./demo/video_1.mp4" , help='output file to store predict mos value') |
|
args = parser.parse_args() |
|
return args |
|
|
|
if __name__ == "__main__": |
|
|
|
args = parse_args() |
|
|
|
""" |
|
BASIC SETTINGS |
|
""" |
|
torch.cuda.current_device() |
|
torch.cuda.empty_cache() |
|
torch.backends.cudnn.benchmark = True |
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
with open(args.opt, "r") as f: |
|
opt = yaml.safe_load(f) |
|
|
|
dopt = opt["data"]["val-ytugc"]["args"] |
|
temporal_samplers = {} |
|
for stype, sopt in dopt["sample_types"].items(): |
|
temporal_samplers[stype] = UnifiedFrameSampler( |
|
sopt["clip_len"] // sopt["t_frag"], |
|
sopt["t_frag"], |
|
sopt["frame_interval"], |
|
sopt["num_clips"], |
|
) |
|
|
|
""" |
|
LOAD MODEL |
|
""" |
|
evaluator = COVER(**opt["model"]["args"]).to(device) |
|
state_dict = torch.load(opt["test_load_path"], map_location=device) |
|
|
|
|
|
|
|
evaluator.load_state_dict(state_dict['state_dict'], strict=False) |
|
|
|
""" |
|
TESTING |
|
""" |
|
views, _ = spatial_temporal_view_decomposition( |
|
args.video_path, dopt["sample_types"], temporal_samplers |
|
) |
|
|
|
for k, v in views.items(): |
|
num_clips = dopt["sample_types"][k].get("num_clips", 1) |
|
if k == 'technical' or k == 'aesthetic': |
|
views[k] = ( |
|
((v.permute(1, 2, 3, 0) - mean) / std) |
|
.permute(3, 0, 1, 2) |
|
.reshape(v.shape[0], num_clips, -1, *v.shape[2:]) |
|
.transpose(0, 1) |
|
.to(device) |
|
) |
|
elif k == 'semantic': |
|
views[k] = ( |
|
((v.permute(1, 2, 3, 0) - mean_clip) / std_clip) |
|
.permute(3, 0, 1, 2) |
|
.reshape(v.shape[0], num_clips, -1, *v.shape[2:]) |
|
.transpose(0, 1) |
|
.to(device) |
|
) |
|
|
|
results = [r.mean().item() for r in evaluator(views)] |
|
pred_score = fuse_results(results) |
|
print(f"path, semantic score, technical score, aesthetic score, overall/final score") |
|
print(f'{args.video_path.split("/")[-1]},{pred_score["semantic"]:4f},{pred_score["technical"]:4f},{pred_score["aesthetic"]:4f},{pred_score["overall"]:4f}') |
|
|
|
|