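"""Webcam demo for TimeSformer video classification (OpenCV only, no Streamlit).

Captures frames from the default camera, keeps a sliding window of the most
recent frames, classifies the clip with a TimeSformer checkpoint, and overlays
the top prediction and FPS on the live video feed.
"""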
from typing import List, Optional, Tuple
import cv2
from pandas import DataFrame
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
import numpy as np
import torch
import pandas as pd
from torch import Tensor
from utils.frame_rate import FrameRate
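
# FrameRate comes from this repo's local utils package. Based on how it is
# used below, it is expected to provide: reset(), count(), a writable `label`
# attribute, and show_fps(frame) -> frame with the FPS and label drawn on it.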

def load_model(model_name: str):
    # The base k400/k600 checkpoints borrow the VideoMAE feature extractor;
    # every other checkpoint loads its own.
    if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            "MCG-NJU/videomae-base-finetuned-kinetics"
        )
    else:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model

class ImgContainer:
    """Holds the latest raw frame, a sliding window of recent frames, and the
    most recent classification results."""

    def __init__(self, frames_per_video: int = 8) -> None:
        self.img: Optional[np.ndarray] = None  # raw image
        self.frame_rate: FrameRate = FrameRate()
        self.imgs: List[np.ndarray] = []
        self.frame_rate.reset()
        self.frames_per_video = frames_per_video
        self.rs: Optional[DataFrame] = None  # top-K predictions as a DataFrame

    def add_frame(self, frame: np.ndarray):
        # Keep only the most recent `frames_per_video` frames.
        if len(self.imgs) >= self.frames_per_video:
            self.imgs.pop(0)
        self.imgs.append(frame)

    @property
    def ready(self):
        return len(self.imgs) == self.frames_per_video

def inference():
    if not img_container.ready:
        return

    # OpenCV captures frames in BGR order; convert to RGB for the extractor.
    frames = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in img_container.imgs]
    inputs = feature_extractor(frames, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits: Tensor = outputs.logits

    # The model predicts one of the 400 Kinetics-400 classes; apply softmax so
    # the reported confidence is an actual percentage rather than a raw logit.
    probs = torch.softmax(logits, dim=-1)
    max_index = int(probs.argmax(-1).item())
    predicted_label = model.config.id2label[max_index]
    img_container.frame_rate.label = (
        f"{predicted_label}_{100 * probs[0][max_index]:.2f}%"
    )

    # Collect the top-K predictions for display.
    TOP_K = 12
    probs_np = probs.squeeze().numpy()
    indices = np.argsort(probs_np)[::-1][:TOP_K]
    values = probs_np[indices]
    results: List[Tuple[str, float]] = [
        (model.config.id2label[int(index)], float(value))
        for index, value in zip(indices, values)
    ]
    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
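
# Note: the main loop below runs inference() on every captured frame once the
# frame buffer is full, so the on-screen FPS reflects model latency as well as
# camera throughput.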

def get_frames_per_video(model_name: str) -> int:
    if "base-finetuned" in model_name:
        return 8
    elif "hr-finetuned" in model_name:
        return 16
    else:
        return 96

# Checkpoint to run; other available checkpoints include:
#   "facebook/timesformer-base-finetuned-k600"
#   "facebook/timesformer-base-finetuned-ssv2"
#   "facebook/timesformer-hr-finetuned-k600"
#   "facebook/timesformer-hr-finetuned-k400"
#   "facebook/timesformer-hr-finetuned-ssv2"
#   "fcakyon/timesformer-large-finetuned-k400"
#   "fcakyon/timesformer-large-finetuned-k600"
model_name = "facebook/timesformer-base-finetuned-k400"
feature_extractor, model = load_model(model_name)
frames_per_video = get_frames_per_video(model_name)
print(f"Frames per video: {frames_per_video}")
img_container = ImgContainer(frames_per_video)

# Open the default webcam.
vid = cv2.VideoCapture(0)

while True:
    # Capture the video frame by frame.
    ret, frame = vid.read()
    if not ret:
        break

    img_container.img = frame
    img_container.frame_rate.count()
    img_container.add_frame(frame)
    inference()

    # Display the frame with the FPS and predicted label overlaid.
    rs = img_container.frame_rate.show_fps(frame)
    cv2.imshow('TimeSFormer', rs)

    # Press 'q' to quit.
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# Release the capture object and close all OpenCV windows.
vid.release()
cv2.destroyAllWindows()