import os
import json
import base64
from io import BytesIO

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo, CenterCropVideo
from transformers import TextStreamer, ProcessorMixin, BatchEncoding
from transformers.image_processing_utils import BatchFeature
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
def load_frames(frames_dir):
    # Collect frame image paths in lexicographic order (frame file names are
    # expected to sort correctly, e.g. zero-padded indices).
    results = []
    frame_names = os.listdir(frames_dir)
    frame_names.sort()
    for frame_name in frame_names:
        image_path = f"{frames_dir}/{frame_name}"
        results.append(image_path)
    return results
def sample_frames(frames, num_segments):
    # Uniformly sample `num_segments` frame paths across the whole clip.
    duration = len(frames)
    frame_id_array = np.linspace(0, duration - 1, num_segments, dtype=int)
    frame_id_list = frame_id_array.tolist()
    sampled_frames = []
    for frame_idx in frame_id_list:
        single_frame_path = frames[frame_idx]
        sampled_frames.append(single_frame_path)
    return sampled_frames
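# Illustrative sketch, not part of the original file: sample_frames picks evenly
# spaced indices via np.linspace, so 4 of 10 hypothetical frame paths map to
# indices 0, 3, 6, 9.
_demo_frames = [f"frames/{i:05d}.jpg" for i in range(10)]  # hypothetical paths
assert sample_frames(_demo_frames, num_segments=4) == [
    "frames/00000.jpg", "frames/00003.jpg", "frames/00006.jpg", "frames/00009.jpg"
]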
class VideoProcessor:
    def __init__(self, image_transform):
        self.image_transform = image_transform

    def __call__(self, video_path, transform=None,
                 video_decode_backend='opencv',
                 clip_start_sec=0.0, clip_end_sec=None,
                 num_frames=50, **kwargs):
        if transform is None:
            transform = self.image_transform

        if video_decode_backend == 'pytorchvideo':
            # EncodedVideo also supports the "pyav" decoder.
            video = EncodedVideo.from_path(video_path, decoder="decord", decode_audio=False)
            duration = video.duration
            start_sec = clip_start_sec  # secs
            end_sec = clip_end_sec if clip_end_sec is not None else duration  # secs
            video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
            video_outputs = transform(video_data)

        elif video_decode_backend == 'decord':
            import decord
            from decord import VideoReader, cpu
            decord.bridge.set_bridge('torch')
            decord_vr = VideoReader(video_path, ctx=cpu(0))
            ori_duration = len(decord_vr)
            fps_vid = decord_vr.get_avg_fps()
            # Sample num_frames uniformly from (at most) the first 10 seconds.
            valid_duration = min(int(fps_vid * 10), ori_duration)
            frame_id_list = np.linspace(0, valid_duration - 1, num_frames, dtype=int)
            video_data = decord_vr.get_batch(frame_id_list)
            video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
            video_outputs = transform(video_data)

        elif video_decode_backend == 'opencv':
            import cv2
            cv2_vr = cv2.VideoCapture(video_path)
            duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
            frame_id_list = np.linspace(0, duration - 1, num_frames, dtype=int)

            video_data = []
            for frame_idx in frame_id_list:
                cv2_vr.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                _, frame = cv2_vr.read()
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                video_data.append(torch.from_numpy(frame).permute(2, 0, 1))  # (H, W, C) -> (C, H, W)
            cv2_vr.release()
            video_data = torch.stack(video_data, dim=1)  # (C, T, H, W)
            video_outputs = transform(video_data)
        elif video_decode_backend == 'frames':
            # FIXME: does not take clip start/end timestamps; duration info would be needed for that.
            frames = load_frames(video_path)
            frames = sample_frames(frames, num_frames)
            # load_frames/sample_frames return file paths, so read each sampled
            # frame from disk before converting it to a tensor.
            to_tensor = ToTensor()
            video_data = torch.stack(
                [to_tensor(Image.open(frame_path).convert('RGB')) for frame_path in frames]
            ).permute(1, 0, 2, 3)  # (T, C, H, W) -> (C, T, H, W)
        else:
            raise ValueError('video_decode_backend should be one of (pytorchvideo, decord, opencv, frames)')