from typing import List, Optional, Tuple
import multiprocessing
import threading

import av
import numpy as np
import pandas as pd
import streamlit as st
import torch
from pandas import DataFrame
from streamlit_webrtc import webrtc_streamer
from torch import Tensor
from transformers import AutoFeatureExtractor, TimesformerForVideoClassification

from utils.frame_rate import FrameRate
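# Fix the NumPy seed and configure the Streamlit page (title, icon, layout, menu).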
np.random.seed(0)

st.set_page_config(
    page_title="TimeSFormer",
    page_icon="🧊",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        "Get Help": "https://www.extremelycoolapp.com/help",
        "Report a bug": "https://www.extremelycoolapp.com/bug",
        "About": "# This is a header. This is an *extremely* cool app!",
    },
)
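# Load the feature extractor and the TimeSformer classification model for the
# selected checkpoint. The base k400/k600 checkpoints fall back to the
# VideoMAE Kinetics feature extractor.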
# @st.experimental_singleton
def load_model(model_name: str):
    if "base-finetuned-k400" in model_name or "base-finetuned-k600" in model_name:
        feature_extractor = AutoFeatureExtractor.from_pretrained(
            "MCG-NJU/videomae-base-finetuned-kinetics"
        )
    else:
        feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = TimesformerForVideoClassification.from_pretrained(model_name)
    return feature_extractor, model
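# Lock protecting the shared image container, which is mutated from the
# streamlit-webrtc video callback thread. The RTC configuration routes the
# stream through a TURN relay so it also works behind restrictive NATs.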
lock = threading.Lock()

rtc_configuration = {
    "iceServers": [
        {
            "urls": "turn:relay1.expressturn.com:3478",
            "username": "efBRTY571ATWBRMP36",
            "credential": "pGcX1BPH5fMmZJc5",
        },
        # {
        #     "urls": [
        #         "stun:stun1.l.google.com:19302",
        #         "stun:stun2.l.google.com:19302",
        #         "stun:stun3.l.google.com:19302",
        #         "stun:stun4.l.google.com:19302",
        #     ]
        # },
    ],
}
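# Run the classifier on the buffered clip: put the top-1 label on the FPS
# overlay and store the top-K (label, logit) pairs as a DataFrame for the UI.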
def inference():
    if not img_container.ready:
        return

    inputs = feature_extractor(list(img_container.imgs), return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)

    logits: Tensor = outputs.logits

    # The model predicts one of the Kinetics-400 / Kinetics-600 / SSv2 classes,
    # depending on the selected checkpoint.
    max_index = logits.argmax(-1).item()
    predicted_label = model.config.id2label[max_index]
    img_container.frame_rate.label = f"{predicted_label}_{logits[0][max_index]:.2f}%"

    TOP_K = 12
    # logits = np.squeeze(logits)
    logits = logits.squeeze().numpy()
    indices = np.argsort(logits)[::-1][:TOP_K]
    values = logits[indices]

    results: List[Tuple[str, float]] = []
    for index, value in zip(indices, values):
        predicted_label = model.config.id2label[index]
        # print(f"Label: {predicted_label} - {value:.2f}%")
        results.append((predicted_label, value))

    # NOTE: the reported values are raw logits, not calibrated probabilities.
    img_container.rs = pd.DataFrame(results, columns=("Label", "Confidence"))
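# Shared state between the video callback and the inference step: the latest
# frame, a rolling buffer of the last `frames_per_video` frames, the FPS
# counter, and the most recent results table.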
class ImgContainer:
    def __init__(self, frames_per_video: int = 8) -> None:
        self.img: Optional[np.ndarray] = None  # latest raw frame
        self.frame_rate: FrameRate = FrameRate()
        self.imgs: List[np.ndarray] = []  # rolling buffer of recent frames
        self.frame_rate.reset()
        self.frames_per_video = frames_per_video
        self.rs: Optional[DataFrame] = None  # latest top-K results

    def add_frame(self, frame: np.ndarray):
        if len(self.imgs) >= self.frames_per_video:
            self.imgs.pop(0)
        self.imgs.append(frame)

    @property
    def ready(self) -> bool:
        return len(self.imgs) == self.frames_per_video
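# Per-frame callback from streamlit-webrtc: convert the frame to a BGR array,
# update the rolling buffer, run inference, and draw the FPS/label overlay.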
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    img = frame.to_ndarray(format="bgr24")

    with lock:
        img_container.img = img
        img_container.frame_rate.count()
        img_container.add_frame(img)
        inference()
        img = img_container.frame_rate.show_fps(img)

    return av.VideoFrame.from_ndarray(img, format="bgr24")
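# Clip length expected by each checkpoint family: 8 frames for the base
# models, 16 for the high-resolution models, 96 for the large models.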
def get_frames_per_video(model_name: str) -> int:
    if "base-finetuned" in model_name:
        return 8
    elif "hr-finetuned" in model_name:
        return 16
    else:
        return 96
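# Streamlit UI: pick a checkpoint, start the webcam stream, and show the
# latest top-K predictions once enough frames have been buffered.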
st.title("TimeSFormer")

with st.expander("INTRODUCTION"):
    st.text(
        f"""Streamlit demo for TimeSFormer.
Number of CPU(s): {multiprocessing.cpu_count()}
"""
    )

model_name = st.selectbox(
    "model_name",
    (
        "facebook/timesformer-base-finetuned-k400",
        "facebook/timesformer-base-finetuned-k600",
        "facebook/timesformer-base-finetuned-ssv2",
        "facebook/timesformer-hr-finetuned-k600",
        "facebook/timesformer-hr-finetuned-k400",
        "facebook/timesformer-hr-finetuned-ssv2",
        "fcakyon/timesformer-large-finetuned-k400",
        "fcakyon/timesformer-large-finetuned-k600",
    ),
)

feature_extractor, model = load_model(model_name)
frames_per_video = get_frames_per_video(model_name)
st.info(f"Frames per video: {frames_per_video}")

img_container = ImgContainer(frames_per_video)

ctx = st.session_state.ctx = webrtc_streamer(
    key="snapshot",
    video_frame_callback=video_frame_callback,
    rtc_configuration=rtc_configuration,
)

if img_container.rs is not None:
    st.dataframe(img_container.rs)
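
# The `FrameRate` helper imported from utils/frame_rate.py is not shown on this
# page. The commented-out sketch below is a hypothetical, minimal implementation
# that matches how it is used above (reset(), count(), show_fps(img), and a
# writable `label` attribute); the actual module in the Space may differ.
#
# utils/frame_rate.py (sketch)
#
# import time
#
# import cv2
# import numpy as np
#
#
# class FrameRate:
#     def __init__(self) -> None:
#         self.label: str = ""
#         self.reset()
#
#     def reset(self) -> None:
#         # Restart the measurement window.
#         self.start = time.time()
#         self.frames = 0
#         self.fps = 0.0
#
#     def count(self) -> None:
#         # Called once per received frame; refresh the FPS estimate roughly
#         # every second.
#         self.frames += 1
#         elapsed = time.time() - self.start
#         if elapsed >= 1.0:
#             self.fps = self.frames / elapsed
#             self.start = time.time()
#             self.frames = 0
#
#     def show_fps(self, img: np.ndarray) -> np.ndarray:
#         # Draw the FPS value and the current prediction label onto the frame.
#         text = f"FPS: {self.fps:.1f} {self.label}"
#         return cv2.putText(
#             img, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2
#         )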