import gradio as gr
import torch
import cv2
import numpy as np
from transformers import VideoMAEForVideoClassification, VideoMAEImageProcessor

# Video classification checkpoint. Note: "MCG-NJU/videomae-base" is the self-supervised
# pre-trained backbone; for meaningful action labels, swap in a fine-tuned checkpoint
# such as "MCG-NJU/videomae-base-finetuned-kinetics".
classification_model_id = "MCG-NJU/videomae-base"

# Object detection model name (not used by the main pipeline; see the optional
# YOLOv5 sketch further down, and replace it with a more accurate model if needed)
object_detection_model = "yolov5s"

# Parameters for frame extraction
TARGET_FRAME_COUNT = 16  # VideoMAE expects exactly 16 frames per clip
FRAME_SIZE = (224, 224)  # Expected frame size for the model


def analyze_video(video):
    # Extract key frames from the video using OpenCV
    frames = extract_key_frames(video)
    if not frames:
        return "Could not read any frames from the uploaded video."

    # Load classification model and image processor
    classification_model = VideoMAEForVideoClassification.from_pretrained(classification_model_id)
    processor = VideoMAEImageProcessor.from_pretrained(classification_model_id)

    # Prepare frames for the classification model (one 16-frame clip)
    inputs = processor(images=frames, return_tensors="pt")

    # Make a prediction with the classification model; VideoMAE produces one
    # prediction per clip, not one per frame
    with torch.no_grad():
        outputs = classification_model(**inputs)
    logits = outputs.logits
    clip_prediction = torch.argmax(logits, dim=-1).item()

    # Object detection (ball and baseman) on each extracted frame
    object_detection_results = []
    for frame in frames:
        ball_position = detect_object(frame, "ball")
        baseman_position = detect_object(frame, "baseman")
        object_detection_results.append((ball_position, baseman_position))

    # Combine the clip-level prediction with the per-frame detections
    analysis_results = []
    for ball_position, baseman_position in object_detection_results:
        result = analyze_frame(clip_prediction, ball_position, baseman_position)
        analysis_results.append(result)

    # Aggregate analysis results
    final_result = aggregate_results(analysis_results)
    return final_result


def extract_key_frames(video):
    cap = cv2.VideoCapture(video)
    frames = []
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    interval = max(1, frame_count // TARGET_FRAME_COUNT)
    for i in range(frame_count):
        ret, frame = cap.read()
        if not ret:
            break
        if i % interval == 0 and len(frames) < TARGET_FRAME_COUNT:  # sample at regular intervals
            frame = cv2.resize(frame, FRAME_SIZE)           # resize to the model's input size
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV gives BGR; the processor expects RGB
            frames.append(frame)
    cap.release()
    # Pad by repeating the last frame if the video is shorter than one clip
    while frames and len(frames) < TARGET_FRAME_COUNT:
        frames.append(frames[-1])
    return frames


def detect_object(frame, object_type):
    # Placeholder function for object detection (replace with an actual detector)
    # Here, we assume that the object is detected at the center of the frame
    h, w, _ = frame.shape
    if object_type == "ball":
        return (w // 2, h // 2)  # Return center coordinates for the ball
    elif object_type == "baseman":
        return (w // 2, h // 2)  # Return center coordinates for the baseman
    else:
        return None
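

# Optional sketch (not wired into analyze_video): real detection with YOLOv5 via
# torch.hub, using the `object_detection_model` constant defined above. This assumes
# the ultralytics/yolov5 hub repo is available and maps COCO classes onto the
# placeholder object types ("sports ball" for the ball, "person" as a rough stand-in
# for the baseman, since "baseman" is not a COCO category). Frames are expected in
# RGB, as returned by extract_key_frames.
_yolo_model = None

def detect_object_yolo(frame, object_type):
    global _yolo_model
    if _yolo_model is None:  # load the detector once, on first use
        _yolo_model = torch.hub.load("ultralytics/yolov5", object_detection_model, pretrained=True)
    target = {"ball": "sports ball", "baseman": "person"}.get(object_type)
    results = _yolo_model(frame)
    for *box, conf, cls in results.xyxy[0].tolist():  # rows: x1, y1, x2, y2, conf, class
        if results.names[int(cls)] == target:
            x1, y1, x2, y2 = box
            return (int((x1 + x2) / 2), int((y1 + y2) / 2))  # center of the first matching box
    return None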


def analyze_frame(prediction, ball_position, baseman_position):
    # Placeholder function for analyzing a single frame
    # (with a fine-tuned checkpoint, the real label names are available from
    # classification_model.config.id2label; replace this mapping accordingly)
    action_labels = {
        0: "running",
        1: "sliding",
        2: "jumping",
        # Add more labels as necessary
    }
    action = action_labels.get(prediction, "unknown")
    return {"action": action, "ball_position": ball_position, "baseman_position": baseman_position}


def aggregate_results(results):
    # Placeholder function for aggregating analysis results
    # You can implement this based on your specific requirements
    return results
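

# Optional sketch (not used by analyze_video): one possible aggregation that turns
# the per-frame results into a short, readable summary for the Gradio text output.
# It assumes each entry has the keys produced by analyze_frame above.
def aggregate_results_text(results):
    from collections import Counter
    if not results:
        return "No frames could be analyzed."
    actions = Counter(r["action"] for r in results)
    top_action, count = actions.most_common(1)[0]
    return f"Dominant action: {top_action} ({count} of {len(results)} frames analyzed)"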


# Gradio interface
interface = gr.Interface(
    fn=analyze_video,
    inputs="video",
    outputs="text",
    title="Baseball Play Analysis",
    description="Upload a video of a baseball play to analyze.",
)

interface.launch()