File size: 4,582 Bytes
8760fb5
70ecfd8
8760fb5
2ac5dcc
 
 
 
8760fb5
 
2ac5dcc
8760fb5
 
 
 
2ac5dcc
8760fb5
944e963
8760fb5
70ecfd8
 
ed24b85
2ac5dcc
 
 
944e963
 
2ac5dcc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
596e2ad
 
2ac5dcc
 
 
 
70ecfd8
 
63e5dc3
70ecfd8
 
63e5dc3
70ecfd8
 
2ac5dcc
 
 
 
 
 
 
 
 
 
 
944e963
8760fb5
 
ed24b85
2ac5dcc
 
8760fb5
 
 
2ac5dcc
 
ed24b85
 
 
 
 
 
 
8760fb5
 
 
 
 
 
 
 
96f34ba
 
 
 
 
8760fb5
96f34ba
8760fb5
 
 
 
 
 
96f34ba
8760fb5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
import tempfile
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
    Normalize,
    UniformTemporalSubsample, RandomShortSideScale,
)
from torchvision.transforms import (
    Compose,
    Lambda,
    Resize, RandomCrop,
)
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor

from video_utils import change_video_resolution_and_fps

# Hugging Face Hub id of the fine-tuned VideoMAE checkpoint served by this app.
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel-4"
# Use the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of clips sampled from each input video; `infer` averages their logits.
CLIPS_FROM_SINGLE_VIDEO = 5

# Load the classifier and its preprocessing configuration once, at import time.
trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

# Normalisation statistics come from the checkpoint's preprocessing config.
mean = image_processor.image_mean
std = image_processor.image_std
# The processor's `size` mapping either holds a single "shortest_edge" value
# (used here for both dimensions) or explicit "height" and "width" entries.
if "shortest_edge" in image_processor.size:
    height = width = image_processor.size["shortest_edge"]
else:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
resize_to = (height, width)

# A clip spans `num_frames_to_sample * sample_rate` source frames; dividing by
# the frame rate turns that span into a duration in seconds.
num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
# NOTE(review): presumably the frame rate inputs are expected to have (same
# value as `new_fps` in parse_video_to_clips) — confirm against training setup.
fps = 30
clip_duration = num_frames_to_sample * sample_rate / fps

# Transform pipeline applied to every sampled clip before it reaches the model.
# NOTE(review): the scale and crop steps are random, so two runs on the same
# video can feed the model different pixels — confirm this is intended for
# inference rather than a deterministic resize.
inference_transform = Compose(
                [
                    # Subsample each clip to `num_frames_to_sample` frames.
                    UniformTemporalSubsample(num_frames_to_sample),
                    # Divide values by 255 — presumably maps 0-255 pixels to [0, 1]; TODO confirm.
                    Lambda(lambda x: x / 255.0),
                    # Normalise with the statistics read from the image processor.
                    Normalize(mean, std),
                    RandomShortSideScale(min_size=256, max_size=320),
                    # Crop to the (height, width) the image processor is configured for.
                    RandomCrop(resize_to),
                ]
            )

# Label names ordered by label id, so position i in `labels` is label id i.
num_labels = trained_model.config.num_labels
labels = [trained_model.config.id2label[i] for i in range(num_labels)]


def parse_video_to_clips(video_file):
    """Turn a video file into a batch of model-ready clips.

    The video is first re-encoded to a fixed resolution and frame rate. Then
    ``CLIPS_FROM_SINGLE_VIDEO`` clips of ``clip_duration`` seconds are sampled
    at random positions, passed through ``inference_transform`` and stacked.

    Args:
        video_file: Path of the video to process.

    Returns:
        A tensor on ``DEVICE`` with one transformed clip per entry along the
        first dimension.
    """
    import os

    new_resolution = (320, 256)
    new_fps = 30
    acceptable_fps_violation = 5
    with tempfile.NamedTemporaryFile() as new_video:
        change_video_resolution_and_fps(video_file, new_video.name, new_resolution, new_fps, acceptable_fps_violation)
        # Decode the converted copy: `clip_duration` assumes the normalised
        # frame rate, so reading the raw upload would discard the conversion.
        # If the helper wrote nothing, fall back to the original file.
        converted_is_usable = os.path.isfile(new_video.name) and os.path.getsize(new_video.name) > 0
        source_path = new_video.name if converted_is_usable else video_file
        video: EncodedVideoPyAV = VideoPathHandler().video_from_path(source_path)

        clip_sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
        # noinspection PyTypeChecker
        clip_info: ClipInfoList = clip_sampler(0, video.duration, {})

        # Extract the clips while the temporary file still exists; it is
        # deleted as soon as the `with` block exits.
        video_clips_list = [
            inference_transform(video.get_clip(clip_start, clip_end)["video"])
            for clip_start, clip_end in zip(clip_info.clip_start_sec, clip_info.clip_end_sec)
        ]

    # Swap the first two dimensions of every clip before stacking —
    # presumably (C, T, H, W) -> (T, C, H, W) as the model expects; confirm.
    videos_tensor = torch.stack([single_clip.permute(1, 0, 2, 3) for single_clip in video_clips_list])
    return videos_tensor.to(DEVICE)


def infer(video_file, threshold=0.5):
    """Classify a video and return a confidence score for every label.

    Args:
        video_file: Path of the video to classify.
        threshold: Currently unused; kept so existing callers keep working.

    Returns:
        A dict mapping each label name to its sigmoid score as a Python float.
    """
    videos_tensor = parse_video_to_clips(video_file)

    # Forward pass over all sampled clips at once; no gradients are needed.
    with torch.no_grad():
        clip_logits = trained_model(pixel_values=videos_tensor).logits
        # Average the per-clip logits into one prediction for the whole video.
        logits = clip_logits.mean(dim=0)

    # Apply the sigmoid directly on the logits tensor. Re-wrapping it with the
    # legacy `torch.Tensor(...)` constructor fails for CUDA tensors, and no
    # squeeze is needed because `mean(dim=0)` already removed the clip axis.
    scores = torch.sigmoid(logits).tolist()
    return dict(zip(labels, scores))


# Build the web UI around `infer` and start serving it.
gr.Interface(
    fn=infer,
    inputs=gr.Video(type="file"),
    # Show only the three highest-scoring labels.
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        # Derive the model link from MODEL_CKPT so it always points at the
        # checkpoint this app actually loads.
        f" <center><a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging=False,
    allow_screenshot=False,
).launch()