TimeSFormer / run_opencv.py
thinh-huynh-re's picture
Refactor
1e87f84
raw
history blame
3.25 kB
from typing import List, Tuple
import cv2
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
from utils.img_container import ImgContainer
def load_model(model_name: str):
    """Load a TimeSformer checkpoint together with a matching feature extractor.

    The Kinetics-400/600 base checkpoints ship without their own preprocessor
    config, so the VideoMAE Kinetics extractor is used for those; every other
    checkpoint provides its own.
    """
    kinetics_base_variants = ("base-finetuned-k400", "base-finetuned-k600")
    if any(variant in model_name for variant in kinetics_base_variants):
        extractor_source = "MCG-NJU/videomae-base-finetuned-kinetics"
    else:
        extractor_source = model_name

    feature_extractor = AutoFeatureExtractor.from_pretrained(extractor_source)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model
def inference():
    """Classify the buffered clip and publish the results to ``img_container``.

    Reads the module-level ``feature_extractor``, ``model`` and
    ``img_container``. No-op until the frame buffer holds a full clip.
    Side effects: sets ``img_container.frame_rate.label`` to the best label
    and ``img_container.rs`` to a DataFrame of the top-K predictions.
    """
    if not img_container.ready:
        return

    inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    logits: Tensor = outputs.logits
    # Convert raw logits to probabilities so the displayed "%" and the
    # "Confidence" column are actually meaningful percentages.
    probs: Tensor = torch.softmax(logits, dim=-1)

    # Model predicts one of the Kinetics classes; overlay the best one.
    max_index = int(probs.argmax(-1).item())
    predicted_label = model.config.id2label[max_index]
    img_container.frame_rate.label = (
        f"{predicted_label}_{probs[0][max_index] * 100:.2f}%"
    )

    TOP_K = 12
    scores = probs.squeeze(0).numpy()
    # Indices of the TOP_K highest-probability classes, best first.
    indices = np.argsort(scores)[::-1][:TOP_K]

    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(i)], float(scores[i])) for i in indices
    ]
    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
def get_frames_per_video(model_name: str) -> int:
    """Return the clip length (number of frames) expected by the checkpoint."""
    frame_counts = (
        ("base-finetuned", 8),
        ("hr-finetuned", 16),
    )
    for marker, frames in frame_counts:
        if marker in model_name:
            return frames
    # Large variants consume 96-frame clips.
    return 96
# Checkpoint to run; any of the commented alternatives below also work.
model_name = "facebook/timesformer-base-finetuned-k400"
# "facebook/timesformer-base-finetuned-k400"
# "facebook/timesformer-base-finetuned-k600",
# "facebook/timesformer-base-finetuned-ssv2",
# "facebook/timesformer-hr-finetuned-k600",
# "facebook/timesformer-hr-finetuned-k400",
# "facebook/timesformer-hr-finetuned-ssv2",
# "fcakyon/timesformer-large-finetuned-k400",
# "fcakyon/timesformer-large-finetuned-k600",
feature_extractor, model = load_model(model_name)
# Clip length the chosen checkpoint expects (8 / 16 / 96 frames).
frames_per_video = get_frames_per_video(model_name)
print(f"Frames per video: {frames_per_video}")
# Rolling buffer that accumulates frames until a full clip is ready.
img_container = ImgContainer(frames_per_video)
# Only every SKIP_FRAMES-th captured frame is pushed into the clip buffer.
SKIP_FRAMES = 4
num_skips = 0
# Open the default webcam for live capture.
vid = cv2.VideoCapture(0)

while True:
    # Capture the video frame by frame.
    ret, frame = vid.read()
    if not ret:
        # Camera disconnected or no frame available; stop instead of
        # passing a None frame downstream and crashing.
        break

    num_skips = (num_skips + 1) % SKIP_FRAMES

    img_container.img = frame
    img_container.frame_rate.count()

    # Throttle inference: only every SKIP_FRAMES-th frame is buffered
    # and fed through the model.
    if num_skips == 0:
        img_container.add_frame(frame)
        inference()

    # Overlay the FPS / prediction label onto the current frame.
    rs = img_container.frame_rate.show_fps(frame)

    # Display the resulting frame.
    cv2.imshow("TimeSFormer", rs)

    # Press 'q' to quit the preview loop.
    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

# Release the capture device and close all OpenCV windows.
vid.release()
cv2.destroyAllWindows()