# TimeSFormer / run_opencv.py
# Source: Hugging Face file view, author thinh-huynh-re, commit 254ea49 ("Update"), 4.81 kB.
from typing import List, Optional, Tuple
import cv2
import numpy as np
import pandas as pd
import torch
from tap import Tap
from torch import Tensor
from transformers import (
AutoFeatureExtractor,
TimesformerForVideoClassification,
VideoMAEFeatureExtractor,
)
from utils.img_container import ImgContainer
class ArgParser(Tap):
    """Live webcam activity recognition with a TimeSformer video classifier."""

    # NOTE(review): presumably the initial recording state; confirm that main()
    # actually reads this flag (recording can also be toggled with the 'r' key).
    is_recording: Optional[bool] = False
    # Example checkpoints for --model_name:
    #   "facebook/timesformer-base-finetuned-k400",
    #   "facebook/timesformer-base-finetuned-k600",
    #   "facebook/timesformer-base-finetuned-ssv2",
    #   "facebook/timesformer-hr-finetuned-k400",
    #   "facebook/timesformer-hr-finetuned-k600",
    #   "facebook/timesformer-hr-finetuned-ssv2",
    #   "fcakyon/timesformer-large-finetuned-k400",
    #   "fcakyon/timesformer-large-finetuned-k600",
    model_name: Optional[str] = "facebook/timesformer-base-finetuned-k400"  # Hugging Face checkpoint id to load
    num_skip_frames: Optional[int] = 4  # Add only every Nth camera frame to the clip buffer
    top_k: Optional[int] = 5  # Number of highest-scoring classes kept in the results table
class ActivityModel:
    """Wrap a Hugging Face TimeSformer checkpoint for clip-level activity inference.

    The feature extractor and model are loaded once in ``__init__``.
    :meth:`inference` then classifies the frames buffered in an
    ``ImgContainer`` and writes the results back onto that container.
    """

    def __init__(self, args: ArgParser):
        self.feature_extractor, self.model = self.load_model(args.model_name)
        self.args = args
        # Clip length the checkpoint expects; main() uses it to size the frame buffer.
        self.frames_per_video = self.get_frames_per_video(args.model_name)
        print(f"Frames per video: {self.frames_per_video}")

    def load_model(
        self, model_name: str
    ) -> Tuple[VideoMAEFeatureExtractor, TimesformerForVideoClassification]:
        """Load ``(feature_extractor, model)`` for the checkpoint *model_name*.

        For the base k400/k600 checkpoints, the VideoMAE Kinetics feature
        extractor is loaded in place of one from *model_name* itself.
        """
        if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
            # NOTE(review): presumably these repos lack a usable preprocessor
            # config, hence the VideoMAE substitute; confirm before changing.
            feature_extractor = AutoFeatureExtractor.from_pretrained(
                "MCG-NJU/videomae-base-finetuned-kinetics"
            )
        else:
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
        model = TimesformerForVideoClassification.from_pretrained(model_name)
        return feature_extractor, model

    def inference(self, img_container: ImgContainer):
        """Classify the buffered clip and publish the results on *img_container*.

        Does nothing until ``img_container.ready`` is truthy. Otherwise:

        * ``img_container.frame_rate.label`` is set to ``"<label>_<pct>%"`` for
          the best class;
        * ``img_container.rs`` is set to a DataFrame with columns
          ``("Label", "Confidence")`` holding the ``args.top_k`` best classes.

        In both, the confidence is the softmax probability expressed in percent.
        """
        if not img_container.ready:
            return

        # NOTE(review): frames come from cv2 (BGR order); confirm whether
        # ImgContainer converts them to RGB before they reach the extractor.
        inputs = self.feature_extractor(list(img_container.imgs), return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**inputs)
        logits: Tensor = outputs.logits

        # Logits are unbounded scores, not percentages. Convert the single
        # clip's scores into class probabilities (in %) before reporting them.
        # Softmax is monotonic, so the ranking of classes is unchanged.
        percent = logits[0].softmax(dim=-1).cpu().numpy() * 100.0
        id2label = self.model.config.id2label

        best = int(np.argmax(percent))
        img_container.frame_rate.label = f"{id2label[best]}_{percent[best]:.2f}%"

        # Class indices ordered by descending probability, truncated to top_k.
        ranked = np.argsort(percent)[::-1][: self.args.top_k]
        results: List[Tuple[str, float]] = [
            (id2label[int(index)], float(percent[index])) for index in ranked
        ]
        img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))

    def get_frames_per_video(self, model_name: str) -> int:
        """Return the clip length for *model_name*, inferred from its name.

        Names containing "base-finetuned" give 8, "hr-finetuned" gives 16,
        and anything else (e.g. the large checkpoints) gives 96.
        """
        if "base-finetuned" in model_name:
            return 8
        elif "hr-finetuned" in model_name:
            return 16
        else:
            return 96
def main(args: ArgParser):
    """Run live activity recognition on the default webcam (device 0).

    Every ``args.num_skip_frames``-th frame is added to the clip buffer and
    the model is run on it. The annotated frame returned by
    ``frame_rate.show_fps`` is displayed in an OpenCV window.

    Keys:
      * ``q`` quits;
      * ``r`` toggles writing the annotated frames to ``activities.mp4``.
    """
    activity_model = ActivityModel(args)
    img_container = ImgContainer(activity_model.frames_per_video)
    # A stride of 0 (or None) would make the modulo below fail; treat it as "every frame".
    skip_stride = max(1, args.num_skip_frames or 1)
    num_skips = 0

    # Define a video capture object for the default camera.
    camera = cv2.VideoCapture(0)
    if not camera.isOpened():
        print("Error reading video file")
        camera.release()
        # Bail out before the VideoWriter creates an empty output file.
        return

    size = (
        int(camera.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    # "mp4v" (lower case) is the tag FFMPEG accepts for .mp4; "MP4V" only
    # works through a fallback that logs a warning.
    video_output = cv2.VideoWriter(
        "activities.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, size
    )

    # Honour --is_recording at startup. Recording is otherwise only
    # switched on through the 'r' key.
    if args.is_recording and not img_container.is_recording:
        img_container.toggle_recording()

    try:
        while camera.isOpened():
            ret, frame = camera.read()
            if not ret:
                # Stream ended or camera disconnected: frame is None and
                # would crash the drawing/display calls below.
                break

            num_skips = (num_skips + 1) % skip_stride
            img_container.img = frame
            img_container.frame_rate.count()

            if num_skips == 0:
                img_container.add_frame(frame)
                activity_model.inference(img_container)

            rs = img_container.frame_rate.show_fps(frame, img_container.is_recording)

            # Display the resulting frame.
            cv2.imshow("ActivityTracking", rs)
            if img_container.is_recording:
                video_output.write(rs)

            # Mask to the low byte: some backends set extra bits in waitKey's result.
            k = cv2.waitKey(1) & 0xFF
            if k == ord("q"):
                break
            elif k == ord("r"):
                img_container.toggle_recording()
    finally:
        # Release the camera and writer and close windows even if the loop raises.
        camera.release()
        video_output.release()
        cv2.destroyAllWindows()
if __name__ == "__main__":
    # Script entry point: parse the typed CLI flags, then start the camera loop.
    args = ArgParser().parse_args()
    main(args=args)