omermazig's picture
Get labels in original order (Got them in the wrong order before, which affected the results)
596e2ad
raw
history blame
4.58 kB
import gradio as gr
import tempfile
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
Normalize,
UniformTemporalSubsample, RandomShortSideScale,
)
from torchvision.transforms import (
Compose,
Lambda,
Resize, RandomCrop,
)
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
from video_utils import change_video_resolution_and_fps
# Hugging Face Hub checkpoint used to load both the model and its feature extractor.
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel-4"
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of clips sampled from every input video.
CLIPS_FROM_SINGLE_VIDEO = 5

trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

# Normalisation statistics come from the checkpoint's preprocessing config.
mean = image_processor.image_mean
std = image_processor.image_std

# The size config holds either a single "shortest_edge" value (used for both
# dimensions) or an explicit "height"/"width" pair.
if "shortest_edge" not in image_processor.size:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
# Duration handed to the clip sampler for each sampled clip.
clip_duration = num_frames_to_sample * sample_rate / fps

# Per-clip preprocessing applied at inference time. The scale and crop steps
# are random, so every sampled clip is scaled/cropped differently.
_transform_steps = [
    UniformTemporalSubsample(num_frames_to_sample),
    Lambda(lambda frames: frames / 255.0),
    Normalize(mean, std),
    RandomShortSideScale(min_size=256, max_size=320),
    RandomCrop(resize_to),
]
inference_transform = Compose(_transform_steps)

# Label names listed by ascending class id, so position i matches logit i.
num_labels = trained_model.config.num_labels
_id2label = trained_model.config.id2label
labels = [_id2label[label_id] for label_id in range(num_labels)]
def parse_video_to_clips(video_file):
    """Sample clips from a video and return them as one preprocessed batch.

    ``CLIPS_FROM_SINGLE_VIDEO`` clips of ``clip_duration`` seconds are drawn
    with the "random_multi" clip sampler, each clip is passed through
    ``inference_transform``, and the results are stacked along a new leading
    dimension.

    Args:
        video_file: Path of the video to read.

    Returns:
        A tensor on ``DEVICE`` holding one entry per sampled clip.
    """
    target_resolution = (320, 256)
    target_fps = 30
    fps_tolerance = 5

    with tempfile.NamedTemporaryFile() as converted_video:
        print(converted_video.name)
        change_video_resolution_and_fps(
            video_file, converted_video.name, target_resolution, target_fps, fps_tolerance
        )
        # NOTE(review): the clips are read from the ORIGINAL `video_file`; the
        # converted copy written to `converted_video.name` is never opened.
        # Confirm whether the converted file was meant to be used here.
        video: EncodedVideoPyAV = VideoPathHandler().video_from_path(video_file)

    sample_clips = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
    # noinspection PyTypeChecker
    clip_info: ClipInfoList = sample_clips(0, video.duration, {})

    # Transform each clip, then swap its first two axes before batching
    # (presumably channels-first -> frames-first, as the model expects; TODO confirm).
    processed_clips = [
        inference_transform(video.get_clip(start_sec, end_sec)["video"]).permute(1, 0, 2, 3)
        for start_sec, end_sec in zip(clip_info.clip_start_sec, clip_info.clip_end_sec)
    ]
    return torch.stack(processed_clips).to(DEVICE)
def infer(video_file, threshold=0.5):
    """Classify a video and return a confidence score for every label.

    Several clips are sampled from the video, the model is run on all of them
    in one batch, and the per-clip logits are averaged before a sigmoid turns
    them into independent per-label scores.

    Args:
        video_file: Path of the video to classify.
        threshold: Currently unused; the raw scores are returned without
            thresholding. Kept so existing callers that pass it keep working.

    Returns:
        A dict mapping each label name to its sigmoid score as a Python float.
    """
    videos_tensor = parse_video_to_clips(video_file)

    # Forward pass only; no gradients are needed for inference.
    with torch.no_grad():
        outputs = trained_model(pixel_values=videos_tensor)

    # Average the logits over the clip (batch) dimension: one logit per label.
    logits = outputs.logits.mean(dim=0)

    # Apply the sigmoid directly to the logits tensor. Re-wrapping it with the
    # legacy `torch.Tensor(...)` constructor is unnecessary and can fail when
    # the logits live on a CUDA device.
    sigmoid_scores = torch.sigmoid(logits).cpu().tolist()

    return dict(zip(labels, sigmoid_scores))
# Build and launch the Gradio demo: one video input, top-3 label output.
gr.Interface(
    fn=infer,
    inputs=gr.Video(type="file"),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        # Derive the model link from MODEL_CKPT so it always points at the
        # checkpoint this demo actually loads.
        f" <center><a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging=False,
    allow_screenshot=False,
).launch()