Spaces:
Runtime error
Runtime error
File size: 4,582 Bytes
8760fb5 70ecfd8 8760fb5 2ac5dcc 8760fb5 2ac5dcc 8760fb5 2ac5dcc 8760fb5 944e963 8760fb5 70ecfd8 ed24b85 2ac5dcc 944e963 2ac5dcc 596e2ad 2ac5dcc 70ecfd8 63e5dc3 70ecfd8 63e5dc3 70ecfd8 2ac5dcc 944e963 8760fb5 ed24b85 2ac5dcc 8760fb5 2ac5dcc ed24b85 8760fb5 96f34ba 8760fb5 96f34ba 8760fb5 96f34ba 8760fb5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
import tempfile
import torch
from pytorchvideo.data import make_clip_sampler
from pytorchvideo.data.clip_sampling import ClipInfoList
from pytorchvideo.data.encoded_video_pyav import EncodedVideoPyAV
from pytorchvideo.data.video import VideoPathHandler
from pytorchvideo.transforms import (
Normalize,
UniformTemporalSubsample, RandomShortSideScale,
)
from torchvision.transforms import (
Compose,
Lambda,
Resize, RandomCrop,
)
from transformers import VideoMAEForVideoClassification, VideoMAEFeatureExtractor
from video_utils import change_video_resolution_and_fps
# Hub checkpoint that provides both the classifier weights and its
# preprocessing configuration.
MODEL_CKPT = "omermazig/videomae-finetuned-nba-5-class-4-batch-8000-vid-multilabel-4"
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of clips sampled from each input video.
CLIPS_FROM_SINGLE_VIDEO = 5

trained_model = VideoMAEForVideoClassification.from_pretrained(MODEL_CKPT).to(DEVICE)
image_processor = VideoMAEFeatureExtractor.from_pretrained(MODEL_CKPT)

# Normalisation statistics come from the checkpoint's processor config.
mean, std = image_processor.image_mean, image_processor.image_std

# The processor's ``size`` mapping either carries a single "shortest_edge"
# value (used for both dimensions) or explicit "height" / "width" entries.
if "shortest_edge" not in image_processor.size:
    height = image_processor.size["height"]
    width = image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Temporal sampling: the model config fixes how many frames one clip holds.
num_frames_to_sample = trained_model.config.num_frames
sample_rate = 4
fps = 30
# Length of one sampled clip (presumably in seconds, given the division by
# ``fps`` -- confirm against the clip sampler's expectations).
clip_duration = num_frames_to_sample * sample_rate / fps
# Preprocessing pipeline applied to every sampled clip before inference:
# temporal subsampling, scaling pixel values by 1/255, normalisation, then a
# short-side rescale and a crop to the model's input size.
# NOTE(review): the rescale and crop steps are the *random* variants, so two
# runs on the same video can yield different inputs -- confirm this is intended.
inference_transform = Compose(
    [
        UniformTemporalSubsample(num_frames_to_sample),
        Lambda(lambda frames: frames / 255.0),
        Normalize(mean, std),
        RandomShortSideScale(min_size=256, max_size=320),
        RandomCrop(resize_to),
    ]
)

# Label names ordered by class index, taken from the model configuration.
num_labels = trained_model.config.num_labels
labels = [trained_model.config.id2label[label_id] for label_id in range(num_labels)]
def parse_video_to_clips(video_file):
    """Sample clips from a video and return them as one batch tensor.

    ``CLIPS_FROM_SINGLE_VIDEO`` clips of ``clip_duration`` are drawn with the
    "random_multi" clip sampler, each clip is passed through
    ``inference_transform``, and the results are stacked along a new leading
    dimension.

    Args:
        video_file: Path of the video to process; it is handed to
            ``change_video_resolution_and_fps`` and to
            ``VideoPathHandler.video_from_path``.

    Returns:
        The stacked clips as a tensor on ``DEVICE``.
    """
    target_resolution = (320, 256)
    target_fps = 30
    fps_tolerance = 5

    # The converted copy is written to a temporary file that is discarded as
    # soon as the ``with`` block exits.
    with tempfile.NamedTemporaryFile() as converted_video:
        print(converted_video.name)
        change_video_resolution_and_fps(
            video_file,
            converted_video.name,
            target_resolution,
            target_fps,
            fps_tolerance,
        )

    # NOTE(review): clips are decoded from the original ``video_file``; the
    # converted copy above is never read back. Confirm whether the converted
    # file was meant to be used here.
    video: EncodedVideoPyAV = VideoPathHandler().video_from_path(video_file)

    sampler = make_clip_sampler("random_multi", clip_duration, CLIPS_FROM_SINGLE_VIDEO)
    # noinspection PyTypeChecker
    clip_info: ClipInfoList = sampler(0, video.duration, {})

    # Swap the first two axes of every transformed clip before stacking
    # (presumably channels-first -> frames-first, as the model expects --
    # TODO confirm).
    clip_batch = torch.stack(
        [
            inference_transform(video.get_clip(start_sec, end_sec)["video"]).permute(1, 0, 2, 3)
            for start_sec, end_sec in zip(clip_info.clip_start_sec, clip_info.clip_end_sec)
        ]
    )
    return clip_batch.to(DEVICE)
def infer(video_file, threshold=0.5):
    """Classify a video and return a confidence score for every label.

    The video is split into clips by ``parse_video_to_clips``, all clips are
    run through ``trained_model`` in a single forward pass, the resulting
    logits are averaged over the first (clip) dimension, and a sigmoid turns
    the averaged logits into independent per-label scores.

    Args:
        video_file: Video input, forwarded unchanged to
            ``parse_video_to_clips``.
        threshold: Currently unused. Scores are returned as-is, without being
            thresholded into binary predictions; the parameter is kept so
            existing callers that pass it keep working.

    Returns:
        A dict mapping each label name in ``labels`` to its sigmoid score as a
        Python float.
    """
    videos_tensor = parse_video_to_clips(video_file)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits_per_clip = trained_model(pixel_values=videos_tensor).logits

    # Average over the clips so the whole video gets one logit per label.
    mean_logits = logits_per_clip.mean(dim=0)

    # ``torch.sigmoid`` works on whatever device/dtype the logits live on,
    # unlike the legacy ``torch.Tensor(...)`` constructor, which only accepts
    # CPU float32 tensors. ``tolist`` copies the scores back as Python floats.
    scores = torch.sigmoid(mean_logits).tolist()
    return dict(zip(labels, scores))
# Build and launch the demo UI: one video input, the top three label scores as
# output, and a handful of bundled example videos.
# Arguments that Gradio deprecated and later removed (``type`` on ``gr.Video``
# and ``allow_screenshot``) are left out, and flagging is disabled with the
# string form "never" instead of the deprecated boolean.
gr.Interface(
    fn=infer,
    inputs=gr.Video(),
    outputs=gr.Label(num_top_classes=3),
    examples=[
        ["examples/DUNK.avi"],
        ["examples/FLOATING_JUMP_SHOT.avi"],
        ["examples/JUMP_SHOT.avi"],
        ["examples/REVERSE_LAYUP.avi"],
        ["examples/TURNAROUND_HOOK_SHOT.avi"],
    ],
    title="VideoMAE fine-tuned on nba data",
    description=(
        "Gradio demo for VideoMAE for video classification. To use it, simply upload your video or click one of the"
        " examples to load them. Read more at the links below."
    ),
    article=(
        "<div style='text-align: center;'><a href='https://huggingface.co/docs/transformers/model_doc/videomae' target='_blank'>VideoMAE</a>"
        # Link to the checkpoint that is actually loaded above, so the page
        # cannot drift from the model in use.
        f" <center><a href='https://huggingface.co/{MODEL_CKPT}' target='_blank'>Fine-tuned Model</a></center></div>"
    ),
    allow_flagging="never",
).launch()
|