Spaces:
Runtime error
Runtime error
import gradio as gr | |
import tempfile | |
import torch | |
from pytorchvideo.data import make_clip_sampler | |
from pytorchvideo.data.clip_sampling import ClipInfoList | |
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV | |
from pytorchvideo.data.video import VideoPathHandler | |
from pytorchvideo.transforms import ( | |
Normalize, | |
UniformTemporalSubsample, RandomShortSideScale, | |
) | |
from torchvision.transforms import ( | |
Compose, | |
Lambda, | |
Resize, RandomCrop, | |
) | |
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor | |
from video_utils import change_video_resolution_and_fps | |
# --- Model and preprocessing configuration ---------------------------------
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel-4"
CLIPS_FROM_SINGLE_VIDEO = 5  # passed to the clip sampler as the number of clips per video
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the fine-tuned classifier onto DEVICE, together with the preprocessing
# configuration stored alongside the same checkpoint.
trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT)
trained_model = trained_model.to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

# Normalization statistics taken from the preprocessor configuration.
mean, std = image_processor.image_mean, image_processor.image_std

# The size config holds either a single "shortest_edge" value or an explicit
# "height"/"width" pair; both forms are reduced to a (height, width) tuple.
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Temporal sampling: one clip spans num_frames_to_sample * sample_rate frames,
# which at `fps` frames per second gives the clip duration used by the sampler.
num_frames_to_sample = trained_model.config.num_frames
sample_rate, fps = 4, 30
clip_duration = num_frames_to_sample * sample_rate / fps
# Preprocessing applied to every sampled clip before it is handed to the model.
# NOTE(review): the scale and crop steps are the *random* variants, so repeated
# runs over the same clip are presumably not deterministic - confirm that this
# is intended for inference.
_inference_steps = [
    # Keep num_frames_to_sample frames per clip.
    UniformTemporalSubsample(num_frames_to_sample),
    # Divide pixel values by 255.0 before normalization.
    Lambda(lambda x: x / 255.0),
    Normalize(mean, std),
    RandomShortSideScale(min_size=256, max_size=320),
    RandomCrop(resize_to),
]
inference_transform = Compose(_inference_steps)
# Class names ordered by class index, so labels[i] is the name stored under
# key i in the model config's id2label mapping.
_id2label = trained_model.config.id2label
num_labels = trained_model.config.num_labels
labels = [_id2label[class_idx] for class_idx in range(num_labels)]
def parse_video_to_clips(video_file):
    """Sample several clips from a video and return them as one batch tensor.

    Args:
        video_file: Path of the video to read.

    Returns:
        A tensor on ``DEVICE`` that stacks the transformed clips (one entry per
        clip returned by the sampler) along a new leading dimension.
    """
    target_resolution = (320, 256)
    target_fps = 30
    fps_tolerance = 5

    # NOTE(review): the converted copy written to the temporary file is never
    # read back - the clips below are decoded from the original `video_file`.
    # Presumably the converted file was meant to be loaded instead; confirm
    # against change_video_resolution_and_fps before changing this.
    with tempfile.NamedTemporaryFile() as converted_video:
        print(converted_video.name)
        change_video_resolution_and_fps(
            video_file, converted_video.name, target_resolution, target_fps, fps_tolerance
        )

    video: EncodedVideoPyAV = VideoPathHandler().video_from_path(video_file)
    sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
    # noinspection PyTypeChecker
    clip_info: ClipInfoList = sampler(0, video.duration, {})

    clip_windows = zip(clip_info.clip_start_sec, clip_info.clip_end_sec)
    # Each clip is transformed and then has its first two axes swapped
    # (presumably channels-first to frames-first - TODO confirm) before stacking.
    batch = torch.stack(
        [
            inference_transform(video.get_clip(start_sec, end_sec)["video"]).permute(1, 0, 2, 3)
            for start_sec, end_sec in clip_windows
        ]
    )
    return batch.to(DEVICE)
def infer(video_file, threshold=0.5):
    """Classify a video and return a per-label confidence mapping.

    The video is split into several clips, the model is run on all of them in
    one batch, the per-clip logits are averaged, and a sigmoid turns the
    averaged logits into independent per-label scores.

    Args:
        video_file: Path of the video to classify.
        threshold: Currently unused - no thresholding is applied to the scores.
            Kept so existing callers that pass it keep working.

    Returns:
        A dict mapping each label name to its sigmoid score as a Python float.
    """
    videos_tensor = parse_video_to_clips(video_file)

    # Forward pass without gradient tracking.
    with torch.no_grad():
        outputs = trained_model(pixel_values=videos_tensor)

    # Average over the clip (batch) dimension to get one logit per label.
    logits = outputs.logits.mean(dim=0)

    # torch.sigmoid works on whatever device the logits live on. The previous
    # torch.Tensor(logits) call used the legacy CPU tensor constructor, which
    # does not accept CUDA tensors, and the extra squeeze(0) would have turned
    # a single-label output into a 0-d tensor that cannot be indexed.
    scores = torch.sigmoid(logits).tolist()

    return dict(zip(labels, scores))
# Build and launch the Gradio demo: one video input, top-3 label output.
gr.Interface(
    fn=infer,
    # gr.Video hands the callback a file path; the component has no documented
    # `type` argument, so none is passed.
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    # NOTE(review): the "Fine-tuned Model" link names a different checkpoint
    # than MODEL_CKPT - confirm which one the page should point to.
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        " <center><a href='https://huggingface.co/omermazig/videomae-finetuned-nba-5-class-8-batch-8000-vid-multiclass_1697155188' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    # Flagging is disabled with the string form ("never") instead of a boolean;
    # the deprecated `allow_screenshot` argument is no longer passed.
    allow_flagging="never",
).launch()