"""Gradio demo that scores an uploaded video with the COVER model.

The app reads ``./cover.yml``, builds the COVER evaluator from the
checkpoint named there, and reports the semantic, technical and aesthetic
scores of the submitted video together with their sum.
"""

import argparse
import functools
import pickle as pkl

# Third-party imports keep their original relative order on purpose:
# reordering native-extension imports (torch / decord) is not risk-free.
import gradio as gr
import torch
import decord
from decord import VideoReader
import numpy as np
import yaml

from cover.datasets import UnifiedFrameSampler, spatial_temporal_view_decomposition
from cover.models import COVER

# Per-channel statistics used to normalise the "technical" and "aesthetic"
# views.  NOTE(review): the magnitudes suggest a 0-255 pixel scale - confirm.
mean, std = (
    torch.FloatTensor([123.675, 116.28, 103.53]),
    torch.FloatTensor([58.395, 57.12, 57.375]),
)

# Per-channel statistics used to normalise the "semantic" view.
mean_clip, std_clip = (
    torch.FloatTensor([122.77, 116.75, 104.09]),
    torch.FloatTensor([68.50, 66.63, 70.32]),
)

# View name -> (mean, std) applied to that view before inference.  Views
# whose name is not listed here are passed to the model untouched.
_VIEW_NORMALIZATION = {
    "technical": (mean, std),
    "aesthetic": (mean, std),
    "semantic": (mean_clip, std_clip),
}


def fuse_results(results: list) -> dict:
    """Label the three branch scores and add their sum.

    Args:
        results: Scores ordered as ``[semantic, technical, aesthetic]``.

    Returns:
        A dict with the keys ``"semantic"``, ``"technical"``,
        ``"aesthetic"`` and ``"overall"`` (the sum of the three scores).
    """
    x = results[0] + results[1] + results[2]
    return {
        "semantic": results[0],
        "technical": results[1],
        "aesthetic": results[2],
        "overall": x,
    }


def _load_config():
    """Parse ``./cover.yml`` and return its content."""
    with open("./cover.yml", "r") as f:
        return yaml.safe_load(f)


@functools.lru_cache(maxsize=1)
def _load_evaluator():
    """Build the COVER evaluator once and reuse it for later requests.

    Returns:
        A ``(evaluator, device)`` tuple: the model with the checkpoint
        weights loaded, and the device it was moved to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    opt = _load_config()

    evaluator = COVER(**opt["model"]["args"]).to(device)
    state_dict = torch.load(opt["test_load_path"], map_location=device)

    # set strict=False here to avoid error of missing
    # weight of prompt_learner in clip-iqa+, cross-gate
    evaluator.load_state_dict(state_dict["state_dict"], strict=False)
    return evaluator, device


def _normalize_view(v, stats, num_clips, device):
    """Normalise one sampled view and split it into clips on ``device``."""
    view_mean, view_std = stats
    # Move the leading axis last so the 3-element statistics broadcast
    # along it, then restore the original axis order.
    normalized = ((v.permute(1, 2, 3, 0) - view_mean) / view_std).permute(3, 0, 1, 2)
    # Split axis 1 into (num_clips, remainder) and put the clip axis first.
    return (
        normalized.reshape(v.shape[0], num_clips, -1, *v.shape[2:])
        .transpose(0, 1)
        .to(device)
    )


def inference_one_video(input_video):
    """Score a single video with COVER.

    Args:
        input_video: Value delivered by the ``gr.Video`` input; it is passed
            unchanged to ``spatial_temporal_view_decomposition``.

    Returns:
        The dict produced by ``fuse_results``.

    Raises:
        gr.Error: If no video was provided.
    """
    if not input_video:
        # Surface a readable message instead of failing inside the decoder.
        raise gr.Error("Please provide a video to score.")

    # BASIC SETTINGS
    opt = _load_config()
    dopt = opt["data"]["val-ytugc"]["args"]

    temporal_samplers = {
        stype: UnifiedFrameSampler(
            sopt["clip_len"] // sopt["t_frag"],
            sopt["t_frag"],
            sopt["frame_interval"],
            sopt["num_clips"],
        )
        for stype, sopt in dopt["sample_types"].items()
    }

    # LOAD MODEL (cached after the first request)
    evaluator, device = _load_evaluator()

    # TESTING
    views, _ = spatial_temporal_view_decomposition(
        input_video, dopt["sample_types"], temporal_samplers
    )

    # Iterate over a snapshot of the keys because entries are replaced.
    for k in list(views):
        num_clips = dopt["sample_types"][k].get("num_clips", 1)
        stats = _VIEW_NORMALIZATION.get(k)
        if stats is not None:
            views[k] = _normalize_view(views[k], stats, num_clips, device)

    # Inference only: switch to eval mode and skip autograd bookkeeping.
    evaluator.eval()
    with torch.no_grad():
        results = [r.mean().item() for r in evaluator(views)]

    return fuse_results(results)


# Define the input and output types for Gradio using the new API
video_input = gr.Video(label="Input Video")
output_label = gr.JSON(label="Scores")

# Create the Gradio interface
gradio_app = gr.Interface(
    fn=inference_one_video, inputs=video_input, outputs=output_label
)

if __name__ == "__main__":
    gradio_app.launch()