from typing import List, Optional, Tuple
import multiprocessing
import threading

import av
import numpy as np
import pandas as pd
import streamlit as st
import torch
from pandas import DataFrame
from streamlit_webrtc import webrtc_streamer
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification

from utils.frame_rate import FrameRate

np.random.seed(0)

st.set_page_config(
    page_title="TimeSFormer",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "https://www.extremelycoolapp.com/help",
        "Report a bug": "https://www.extremelycoolapp.com/bug",
        "About": "# This is a header. This is an *extremely* cool app!",
    },
)


@st.cache_resource
# @st.experimental_singleton
def load_model(model_name: str):
    """Load the feature extractor and TimeSformer model for the selected checkpoint."""
    if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
        # For these checkpoints, reuse the VideoMAE Kinetics feature extractor,
        # which provides the same video preprocessing pipeline.
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            "MCG-NJU/videomae-base-finetuned-kinetics"
        )
    else:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model


lock = threading.Lock()

rtc_configuration = {
    "iceServers": [
        {
            "urls": "turn:relay1.expressturn.com:3478",
            "username": "efBRTY571ATWBRMP36",
            "credential": "pGcX1BPH5fMmZJc5",
        },
        # {
        #     "urls": [
        #         "stun:stun1.l.google.com:19302",
        #         "stun:stun2.l.google.com:19302",
        #         "stun:stun3.l.google.com:19302",
        #         "stun:stun4.l.google.com:19302",
        #     ]
        # },
    ],
}


def inference():
    """Classify the buffered clip and store the top-k predictions on the container."""
    if not img_container.ready:
        return

    inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits: Tensor = outputs.logits

    # The model predicts one class of the selected label space
    # (Kinetics-400, Kinetics-600, or Something-Something-v2).
    # Convert logits to probabilities so the displayed values are real percentages.
    probs: Tensor = logits.softmax(-1)
    max_index = probs.argmax(-1).item()
    predicted_label = model.config.id2label[max_index]
    img_container.frame_rate.label = (
        f"{predicted_label}_{probs[0][max_index] * 100:.2f}%"
    )

    TOP_K = 12
    scores = probs.squeeze().numpy()
    indices = np.argsort(scores)[::-1][:TOP_K]
    values = scores[indices]

    results: List[Tuple[str, float]] = []
    for index, value in zip(indices, values):
        results.append((model.config.id2label[index], value))

    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))


class ImgContainer:
    """Holds the latest frame, a sliding window of frames, and the latest predictions."""

    def __init__(self, frames_per_video: int = 8) -> None:
        self.img: Optional[np.ndarray] = None  # raw image
        self.frame_rate: FrameRate = FrameRate()
        self.imgs: List[np.ndarray] = []
        self.frame_rate.reset()
        self.frames_per_video = frames_per_video
        self.rs: Optional[DataFrame] = None  # top-k prediction table

    def add_frame(self, frame: np.ndarray) -> None:
        # Keep only the most recent `frames_per_video` frames.
        if len(self.imgs) >= self.frames_per_video:
            self.imgs.pop(0)
        self.imgs.append(frame)

    @property
    def ready(self) -> bool:
        return len(self.imgs) == self.frames_per_video


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Receive a frame from the browser, run inference, and return the annotated frame."""
    img = frame.to_ndarray(format="bgr24")
    with lock:
        img_container.img = img
        img_container.frame_rate.count()
        img_container.add_frame(img)
        inference()
        img = img_container.frame_rate.show_fps(img)
    return av.VideoFrame.from_ndarray(img, format="bgr24")


def get_frames_per_video(model_name: str) -> int:
    """Return the clip length (in frames) expected by the selected checkpoint."""
    if "base-finetuned" in model_name:
        return 8
    elif "hr-finetuned" in model_name:
        return 16
    else:
        return 96


st.title("TimeSFormer")

with st.expander("INTRODUCTION"):
    st.text(
        f"""Streamlit demo for TimeSFormer.
Number of CPU(s): {multiprocessing.cpu_count()}"""
    )

model_name = st.selectbox(
    "model_name",
    (
        "facebook/timesformer-base-finetuned-k400",
        "facebook/timesformer-base-finetuned-k600",
        "facebook/timesformer-base-finetuned-ssv2",
        "facebook/timesformer-hr-finetuned-k600",
        "facebook/timesformer-hr-finetuned-k400",
        "facebook/timesformer-hr-finetuned-ssv2",
        "fcakyon/timesformer-large-finetuned-k400",
        "fcakyon/timesformer-large-finetuned-k600",
    ),
)

feature_extractor, model = load_model(model_name)
frames_per_video = get_frames_per_video(model_name)

st.info(f"Frames per video: {frames_per_video}")

img_container = ImgContainer(frames_per_video)

ctx = st.session_state.ctx = webrtc_streamer(
    key="snapshot",
    video_frame_callback=video_frame_callback,
    rtc_configuration=rtc_configuration,
)

if img_container.rs is not None:
    st.dataframe(img_container.rs)
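

# ---------------------------------------------------------------------------
# The app imports FrameRate from utils.frame_rate, which is not included in
# this file. Below is a minimal sketch (an assumption, not the actual
# utils/frame_rate.py) of a helper matching the interface used above:
# reset(), count(), a writable `label` attribute, and show_fps(img) returning
# the annotated frame. It assumes opencv-python is installed for the overlay.
# ---------------------------------------------------------------------------
import time

import cv2


class _FrameRateSketch:
    """Hypothetical stand-in for utils.frame_rate.FrameRate."""

    def __init__(self) -> None:
        self.label: str = ""  # set by inference() to "<label>_<confidence>%"
        self.reset()

    def reset(self) -> None:
        # Start a new measurement window.
        self._start = time.monotonic()
        self._frames = 0
        self.fps = 0.0

    def count(self) -> None:
        # Called once per incoming frame; refresh the FPS estimate every second.
        self._frames += 1
        elapsed = time.monotonic() - self._start
        if elapsed >= 1.0:
            self.fps = self._frames / elapsed
            self._start = time.monotonic()
            self._frames = 0

    def show_fps(self, img: np.ndarray) -> np.ndarray:
        # Draw the FPS counter and the latest prediction label onto the frame.
        text = f"FPS: {self.fps:.1f}  {self.label}"
        cv2.putText(
            img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2
        )
        return img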