giahan commited on
Commit
9e3c23c
1 Parent(s): e94e369

Use opencv only (no streamlit)

Browse files
Files changed (1) hide show
  1. test_opencv.py +121 -0
test_opencv.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Tuple
2
+ import cv2
3
+ from pandas import DataFrame
4
+ from transformers import AutoFeatureExtractor, TimesformerForVideoClassification
5
+ import numpy as np
6
+ import torch
7
+ import pandas as pd
8
+ from torch import Tensor
9
+
10
+ from utils.frame_rate import FrameRate
11
+
12
+ def load_model(model_name: str):
13
+ if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
14
+ feature_extractor = AutoFeatureExtractor.from_pretrained(
15
+ "MCG-NJU/videomae-base-finetuned-kinetics"
16
+ )
17
+ else:
18
+ feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
19
+ model = TimesformerForVideoClassification.from_pretrained(model_name)
20
+ return feature_extractor, model
21
+
22
+ class ImgContainer:
23
+ def __init__(self, frames_per_video: int = 8) -> None:
24
+ self.img: Optional[np.ndarray] = None # raw image
25
+ self.frame_rate: FrameRate = FrameRate()
26
+ self.imgs: List[np.ndarray] = []
27
+ self.frame_rate.reset()
28
+ self.frames_per_video = frames_per_video
29
+ self.rs: Optional[DataFrame] = None
30
+
31
+ def add_frame(self, frame: np.ndarray):
32
+ if len(img_container.imgs) >= frames_per_video:
33
+ self.imgs.pop(0)
34
+ self.imgs.append(frame)
35
+
36
+ @property
37
+ def ready(self):
38
+ return len(img_container.imgs) == self.frames_per_video
39
+
40
+ def inference():
41
+ if not img_container.ready:
42
+ return
43
+
44
+ inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")
45
+
46
+ with torch.no_grad():
47
+ outputs = model(**inputs)
48
+ logits: Tensor = outputs.logits
49
+
50
+ # model predicts one of the 400 Kinetics-400 classes
51
+ max_index = logits.argmax(-1).item()
52
+ predicted_label = model.config.id2label[max_index]
53
+
54
+ img_container.frame_rate.label = f"{predicted_label}_{logits[0][max_index]:.2f}%"
55
+
56
+ TOP_K = 12
57
+ # logits = np.squeeze(logits)
58
+ logits = logits.squeeze().numpy()
59
+ indices = np.argsort(logits)[::-1][:TOP_K]
60
+ values = logits[indices]
61
+
62
+ results: List[Tuple[str, float]] = []
63
+ for index, value in zip(indices, values):
64
+ predicted_label = model.config.id2label[index]
65
+ # print(f"Label: {predicted_label} - {value:.2f}%")
66
+ results.append((predicted_label, value))
67
+
68
+ img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
69
+
70
+ def get_frames_per_video(model_name: str) -> int:
71
+ if "base-finetuned" in model_name:
72
+ return 8
73
+ elif "hr-finetuned" in model_name:
74
+ return 16
75
+ else:
76
+ return 96
77
+
78
+
79
+ model_name = "facebook/timesformer-base-finetuned-k400"
80
+ # "facebook/timesformer-base-finetuned-k600",
81
+ # "facebook/timesformer-base-finetuned-ssv2",
82
+ # "facebook/timesformer-hr-finetuned-k600",
83
+ # "facebook/timesformer-hr-finetuned-k400",
84
+ # "facebook/timesformer-hr-finetuned-ssv2",
85
+ # "fcakyon/timesformer-large-finetuned-k400",
86
+ # "fcakyon/timesformer-large-finetuned-k600",
87
+ feature_extractor, model = load_model(model_name)
88
+
89
+
90
+ frames_per_video = get_frames_per_video(model_name)
91
+ print(f"Frames per video: {frames_per_video}")
92
+
93
+ img_container = ImgContainer(frames_per_video)
94
+
95
+ # define a video capture object
96
+ vid = cv2.VideoCapture(0)
97
+
98
+ while(True):
99
+ # Capture the video frame
100
+ # by frame
101
+ ret, frame = vid.read()
102
+
103
+ img_container.img = frame
104
+ img_container.frame_rate.count()
105
+ img_container.add_frame(frame)
106
+ inference()
107
+ rs = img_container.frame_rate.show_fps(frame)
108
+
109
+ # Display the resulting frame
110
+ cv2.imshow('TimeSFormer', rs)
111
+
112
+ # the 'q' button is set as the
113
+ # quitting button you may use any
114
+ # desired button of your choice
115
+ if cv2.waitKey(1) & 0xFF == ord('q'):
116
+ break
117
+
118
+ # After the loop release the cap object
119
+ vid.release()
120
+ # Destroy all the windows
121
+ cv2.destroyAllWindows()