import os.path as osp import random import cv2 import decord import numpy as np import skvideo.io import torch import torchvision from decord import VideoReader, cpu, gpu from tqdm import tqdm random.seed(42) decord.bridge.set_bridge("torch") def get_spatial_fragments( video, fragments_h=7, fragments_w=7, fsize_h=32, fsize_w=32, aligned=32, nfrags=1, random=False, fallback_type="upsample", ): size_h = fragments_h * fsize_h size_w = fragments_w * fsize_w ## situation for images if video.shape[1] == 1: aligned = 1 dur_t, res_h, res_w = video.shape[-3:] ratio = min(res_h / size_h, res_w / size_w) if fallback_type == "upsample" and ratio < 1: ovideo = video video = torch.nn.functional.interpolate( video / 255.0, scale_factor=1 / ratio, mode="bilinear" ) video = (video * 255.0).type_as(ovideo) assert dur_t % aligned == 0, "Please provide match vclip and align index" size = size_h, size_w ## make sure that sampling will not run out of the picture hgrids = torch.LongTensor( [min(res_h // fragments_h * i, res_h - fsize_h) for i in range(fragments_h)] ) wgrids = torch.LongTensor( [min(res_w // fragments_w * i, res_w - fsize_w) for i in range(fragments_w)] ) hlength, wlength = res_h // fragments_h, res_w // fragments_w if random: print("This part is deprecated. Please remind that.") if res_h > fsize_h: rnd_h = torch.randint( res_h - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) ) else: rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() if res_w > fsize_w: rnd_w = torch.randint( res_w - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) ) else: rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() else: if hlength > fsize_h: rnd_h = torch.randint( hlength - fsize_h, (len(hgrids), len(wgrids), dur_t // aligned) ) else: rnd_h = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() if wlength > fsize_w: rnd_w = torch.randint( wlength - fsize_w, (len(hgrids), len(wgrids), dur_t // aligned) ) else: rnd_w = torch.zeros((len(hgrids), len(wgrids), dur_t // aligned)).int() target_video = torch.zeros(video.shape[:-2] + size).to(video.device) # target_videos = [] for i, hs in enumerate(hgrids): for j, ws in enumerate(wgrids): for t in range(dur_t // aligned): t_s, t_e = t * aligned, (t + 1) * aligned h_s, h_e = i * fsize_h, (i + 1) * fsize_h w_s, w_e = j * fsize_w, (j + 1) * fsize_w if random: h_so, h_eo = rnd_h[i][j][t], rnd_h[i][j][t] + fsize_h w_so, w_eo = rnd_w[i][j][t], rnd_w[i][j][t] + fsize_w else: h_so, h_eo = hs + rnd_h[i][j][t], hs + rnd_h[i][j][t] + fsize_h w_so, w_eo = ws + rnd_w[i][j][t], ws + rnd_w[i][j][t] + fsize_w target_video[:, t_s:t_e, h_s:h_e, w_s:w_e] = video[ :, t_s:t_e, h_so:h_eo, w_so:w_eo ] # target_videos.append(video[:,t_s:t_e,h_so:h_eo,w_so:w_eo]) # target_video = torch.stack(target_videos, 0).reshape((dur_t // aligned, fragments, fragments,) + target_videos[0].shape).permute(3,0,4,1,5,2,6) # target_video = target_video.reshape((-1, dur_t,) + size) ## Splicing Fragments return target_video class FragmentSampleFrames: def __init__(self, fsize_t, fragments_t, frame_interval=1, num_clips=1): self.fragments_t = fragments_t self.fsize_t = fsize_t self.size_t = fragments_t * fsize_t self.frame_interval = frame_interval self.num_clips = num_clips def get_frame_indices(self, num_frames): tgrids = np.array( [num_frames // self.fragments_t * i for i in range(self.fragments_t)], dtype=np.int32, ) tlength = num_frames // self.fragments_t if tlength > self.fsize_t * self.frame_interval: rnd_t = np.random.randint( 0, tlength - self.fsize_t * self.frame_interval, size=len(tgrids) ) else: rnd_t = np.zeros(len(tgrids), dtype=np.int32) ranges_t = ( np.arange(self.fsize_t)[None, :] * self.frame_interval + rnd_t[:, None] + tgrids[:, None] ) return np.concatenate(ranges_t) def __call__(self, total_frames, train=False, start_index=0): frame_inds = [] for i in range(self.num_clips): frame_inds += [self.get_frame_indices(total_frames)] frame_inds = np.concatenate(frame_inds) frame_inds = np.mod(frame_inds + start_index, total_frames) return frame_inds class SampleFrames: def __init__(self, clip_len, frame_interval=1, num_clips=1): self.clip_len = clip_len self.frame_interval = frame_interval self.num_clips = num_clips def _get_train_clips(self, num_frames): """Get clip offsets in train mode. It will calculate the average interval for selected frames, and randomly shift them within offsets between [0, avg_interval]. If the total number of frames is smaller than clips num or origin frames length, it will return all zero indices. Args: num_frames (int): Total number of frame in the video. Returns: np.ndarray: Sampled frame indices in train mode. """ ori_clip_len = self.clip_len * self.frame_interval avg_interval = (num_frames - ori_clip_len + 1) // self.num_clips if avg_interval > 0: base_offsets = np.arange(self.num_clips) * avg_interval clip_offsets = base_offsets + np.random.randint( avg_interval, size=self.num_clips ) elif num_frames > max(self.num_clips, ori_clip_len): clip_offsets = np.sort( np.random.randint(num_frames - ori_clip_len + 1, size=self.num_clips) ) elif avg_interval == 0: ratio = (num_frames - ori_clip_len + 1.0) / self.num_clips clip_offsets = np.around(np.arange(self.num_clips) * ratio) else: clip_offsets = np.zeros((self.num_clips,), dtype=np.int) return clip_offsets def _get_test_clips(self, num_frames, start_index=0): """Get clip offsets in test mode. Calculate the average interval for selected frames, and shift them fixedly by avg_interval/2. Args: num_frames (int): Total number of frame in the video. Returns: np.ndarray: Sampled frame indices in test mode. """ ori_clip_len = self.clip_len * self.frame_interval avg_interval = (num_frames - ori_clip_len + 1) / float(self.num_clips) if num_frames > ori_clip_len - 1: base_offsets = np.arange(self.num_clips) * avg_interval clip_offsets = (base_offsets + avg_interval / 2.0).astype(np.int32) else: clip_offsets = np.zeros((self.num_clips,), dtype=np.int32) return clip_offsets def __call__(self, total_frames, train=False, start_index=0): """Perform the SampleFrames loading. Args: results (dict): The resulting dict to be modified and passed to the next transform in pipeline. """ if train: clip_offsets = self._get_train_clips(total_frames) else: clip_offsets = self._get_test_clips(total_frames) frame_inds = ( clip_offsets[:, None] + np.arange(self.clip_len)[None, :] * self.frame_interval ) frame_inds = np.concatenate(frame_inds) frame_inds = frame_inds.reshape((-1, self.clip_len)) frame_inds = np.mod(frame_inds, total_frames) frame_inds = np.concatenate(frame_inds) + start_index return frame_inds.astype(np.int32) class FastVQAPlusPlusDataset(torch.utils.data.Dataset): def __init__( self, ann_file, data_prefix, frame_interval=2, aligned=32, fragments=(8, 8, 8), fsize=(4, 32, 32), num_clips=1, nfrags=1, cache_in_memory=False, phase="test", fallback_type="oversample", ): """ Fragments. args: fragments: G_f as in the paper. fsize: S_f as in the paper. nfrags: number of samples (spatially) as in the paper. num_clips: number of samples (temporally) as in the paper. """ self.ann_file = ann_file self.data_prefix = data_prefix self.frame_interval = frame_interval self.num_clips = num_clips self.fragments = fragments self.fsize = fsize self.nfrags = nfrags self.clip_len = fragments[0] * fsize[0] self.aligned = aligned self.fallback_type = fallback_type self.sampler = FragmentSampleFrames( fsize[0], fragments[0], frame_interval, num_clips ) self.video_infos = [] self.phase = phase self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) self.std = torch.FloatTensor([58.395, 57.12, 57.375]) if isinstance(self.ann_file, list): self.video_infos = self.ann_file else: with open(self.ann_file, "r") as fin: for line in fin: line_split = line.strip().split(",") filename, _, _, label = line_split label = float(label) filename = osp.join(self.data_prefix, filename) self.video_infos.append(dict(filename=filename, label=label)) if cache_in_memory: self.cache = {} for i in tqdm(range(len(self)), desc="Caching fragments"): self.cache[i] = self.__getitem__(i, tocache=True) else: self.cache = None def __getitem__( self, index, tocache=False, need_original_frames=False, ): if tocache or self.cache is None: fx, fy = self.fragments[1:] fsx, fsy = self.fsize[1:] video_info = self.video_infos[index] filename = video_info["filename"] label = video_info["label"] if filename.endswith(".yuv"): video = skvideo.io.vread( filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} ) frame_inds = self.sampler(video.shape[0], self.phase == "train") imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] else: vreader = VideoReader(filename) frame_inds = self.sampler(len(vreader), self.phase == "train") frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} imgs = [frame_dict[idx] for idx in frame_inds] img_shape = imgs[0].shape video = torch.stack(imgs, 0) video = video.permute(3, 0, 1, 2) if self.nfrags == 1: vfrag = get_spatial_fragments( video, fx, fy, fsx, fsy, aligned=self.aligned, fallback_type=self.fallback_type, ) else: vfrag = get_spatial_fragments( video, fx, fy, fsx, fsy, aligned=self.aligned, fallback_type=self.fallback_type, ) for i in range(1, self.nfrags): vfrag = torch.cat( ( vfrag, get_spatial_fragments( video, fragments, fx, fy, fsx, fsy, aligned=self.aligned, fallback_type=self.fallback_type, ), ), 1, ) if tocache: return (vfrag, frame_inds, label, img_shape) else: vfrag, frame_inds, label, img_shape = self.cache[index] vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) data = { "video": vfrag.reshape( (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] ).transpose( 0, 1 ), # B, V, T, C, H, W "frame_inds": frame_inds, "gt_label": label, "original_shape": img_shape, } if need_original_frames: data["original_video"] = video.reshape( (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] ).transpose(0, 1) return data def __len__(self): return len(self.video_infos) class FragmentVideoDataset(torch.utils.data.Dataset): def __init__( self, ann_file, data_prefix, clip_len=32, frame_interval=2, num_clips=4, aligned=32, fragments=7, fsize=32, nfrags=1, cache_in_memory=False, phase="test", ): """ Fragments. args: fragments: G_f as in the paper. fsize: S_f as in the paper. nfrags: number of samples as in the paper. """ self.ann_file = ann_file self.data_prefix = data_prefix self.clip_len = clip_len self.frame_interval = frame_interval self.num_clips = num_clips self.fragments = fragments self.fsize = fsize self.nfrags = nfrags self.aligned = aligned self.sampler = SampleFrames(clip_len, frame_interval, num_clips) self.video_infos = [] self.phase = phase self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) self.std = torch.FloatTensor([58.395, 57.12, 57.375]) if isinstance(self.ann_file, list): self.video_infos = self.ann_file else: with open(self.ann_file, "r") as fin: for line in fin: line_split = line.strip().split(",") filename, _, _, label = line_split label = float(label) filename = osp.join(self.data_prefix, filename) self.video_infos.append(dict(filename=filename, label=label)) if cache_in_memory: self.cache = {} for i in tqdm(range(len(self)), desc="Caching fragments"): self.cache[i] = self.__getitem__(i, tocache=True) else: self.cache = None def __getitem__( self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False, ): if tocache or self.cache is None: if fragments == -1: fragments = self.fragments if fsize == -1: fsize = self.fsize video_info = self.video_infos[index] filename = video_info["filename"] label = video_info["label"] if filename.endswith(".yuv"): video = skvideo.io.vread( filename, 1080, 1920, inputdict={"-pix_fmt": "yuvj420p"} ) frame_inds = self.sampler(video.shape[0], self.phase == "train") imgs = [torch.from_numpy(video[idx]) for idx in frame_inds] else: vreader = VideoReader(filename) frame_inds = self.sampler(len(vreader), self.phase == "train") frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} imgs = [frame_dict[idx] for idx in frame_inds] img_shape = imgs[0].shape video = torch.stack(imgs, 0) video = video.permute(3, 0, 1, 2) if self.nfrags == 1: vfrag = get_spatial_fragments( video, fragments, fragments, fsize, fsize, aligned=self.aligned ) else: vfrag = get_spatial_fragments( video, fragments, fragments, fsize, fsize, aligned=self.aligned ) for i in range(1, self.nfrags): vfrag = torch.cat( ( vfrag, get_spatial_fragments( video, fragments, fragments, fsize, fsize, aligned=self.aligned, ), ), 1, ) if tocache: return (vfrag, frame_inds, label, img_shape) else: vfrag, frame_inds, label, img_shape = self.cache[index] vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) data = { "video": vfrag.reshape( (-1, self.nfrags * self.num_clips, self.clip_len) + vfrag.shape[2:] ).transpose( 0, 1 ), # B, V, T, C, H, W "frame_inds": frame_inds, "gt_label": label, "original_shape": img_shape, } if need_original_frames: data["original_video"] = video.reshape( (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] ).transpose(0, 1) return data def __len__(self): return len(self.video_infos) class ResizedVideoDataset(torch.utils.data.Dataset): def __init__( self, ann_file, data_prefix, clip_len=32, frame_interval=2, num_clips=4, aligned=32, size=224, cache_in_memory=False, phase="test", ): """ Using resizing. """ self.ann_file = ann_file self.data_prefix = data_prefix self.clip_len = clip_len self.frame_interval = frame_interval self.num_clips = num_clips self.size = size self.aligned = aligned self.sampler = SampleFrames(clip_len, frame_interval, num_clips) self.video_infos = [] self.phase = phase self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) self.std = torch.FloatTensor([58.395, 57.12, 57.375]) if isinstance(self.ann_file, list): self.video_infos = self.ann_file else: with open(self.ann_file, "r") as fin: for line in fin: line_split = line.strip().split(",") filename, _, _, label = line_split label = float(label) filename = osp.join(self.data_prefix, filename) self.video_infos.append(dict(filename=filename, label=label)) if cache_in_memory: self.cache = {} for i in tqdm(range(len(self)), desc="Caching resized videos"): self.cache[i] = self.__getitem__(i, tocache=True) else: self.cache = None def __getitem__(self, index, tocache=False, need_original_frames=False): if tocache or self.cache is None: video_info = self.video_infos[index] filename = video_info["filename"] label = video_info["label"] vreader = VideoReader(filename) frame_inds = self.sampler(len(vreader), self.phase == "train") frame_dict = {idx: vreader[idx] for idx in np.unique(frame_inds)} imgs = [frame_dict[idx] for idx in frame_inds] img_shape = imgs[0].shape video = torch.stack(imgs, 0) video = video.permute(3, 0, 1, 2) video = torch.nn.functional.interpolate(video, size=(self.size, self.size)) if tocache: return (vfrag, frame_inds, label, img_shape) else: vfrag, frame_inds, label, img_shape = self.cache[index] vfrag = ((vfrag.permute(1, 2, 3, 0) - self.mean) / self.std).permute(3, 0, 1, 2) data = { "video": vfrag.reshape( (-1, self.num_clips, self.clip_len) + vfrag.shape[2:] ).transpose( 0, 1 ), # B, V, T, C, H, W "frame_inds": frame_inds, "gt_label": label, "original_shape": img_shape, } if need_original_frames: data["original_video"] = video.reshape( (-1, self.nfrags * self.num_clips, self.clip_len) + video.shape[2:] ).transpose(0, 1) return data def __len__(self): return len(self.video_infos) class CroppedVideoDataset(FragmentVideoDataset): def __init__( self, ann_file, data_prefix, clip_len=32, frame_interval=2, num_clips=4, aligned=32, size=224, ncrops=1, cache_in_memory=False, phase="test", ): """ Regard Cropping as a special case for Fragments in Grid 1*1. """ super().__init__( ann_file, data_prefix, clip_len=clip_len, frame_interval=frame_interval, num_clips=num_clips, aligned=aligned, fragments=1, fsize=224, nfrags=ncrops, cache_in_memory=cache_in_memory, phase=phase, ) class FragmentImageDataset(torch.utils.data.Dataset): def __init__( self, ann_file, data_prefix, fragments=7, fsize=32, nfrags=1, cache_in_memory=False, phase="test", ): self.ann_file = ann_file self.data_prefix = data_prefix self.fragments = fragments self.fsize = fsize self.nfrags = nfrags self.image_infos = [] self.phase = phase self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) self.std = torch.FloatTensor([58.395, 57.12, 57.375]) if isinstance(self.ann_file, list): self.image_infos = self.ann_file else: with open(self.ann_file, "r") as fin: for line in fin: line_split = line.strip().split(",") filename, _, _, label = line_split label = float(label) filename = osp.join(self.data_prefix, filename) self.image_infos.append(dict(filename=filename, label=label)) if cache_in_memory: self.cache = {} for i in tqdm(range(len(self)), desc="Caching fragments"): self.cache[i] = self.__getitem__(i, tocache=True) else: self.cache = None def __getitem__( self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False ): if tocache or self.cache is None: if fragments == -1: fragments = self.fragments if fsize == -1: fsize = self.fsize image_info = self.image_infos[index] filename = image_info["filename"] label = image_info["label"] try: img = torchvision.io.read_image(filename) except: img = cv2.imread(filename) img = torch.from_numpy(img[:, :, [2, 1, 0]]).permute(2, 0, 1) img_shape = img.shape[1:] image = img.unsqueeze(1) if self.nfrags == 1: ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) else: ifrag = get_spatial_fragments(image, fragments, fragments, fsize, fsize) for i in range(1, self.nfrags): ifrag = torch.cat( ( ifrag, get_spatial_fragments( image, fragments, fragments, fsize, fsize ), ), 1, ) if tocache: return (ifrag, label, img_shape) else: ifrag, label, img_shape = self.cache[index] if self.nfrags == 1: ifrag = ( ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) .squeeze(0) .permute(2, 0, 1) ) else: ### During testing, one image as a batch ifrag = ( ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) .squeeze(0) .permute(0, 3, 1, 2) ) data = { "image": ifrag, "gt_label": label, "original_shape": img_shape, "name": filename, } if need_original_frames: data["original_image"] = image.squeeze(1) return data def __len__(self): return len(self.image_infos) class ResizedImageDataset(torch.utils.data.Dataset): def __init__( self, ann_file, data_prefix, size=224, cache_in_memory=False, phase="test", ): self.ann_file = ann_file self.data_prefix = data_prefix self.size = size self.image_infos = [] self.phase = phase self.mean = torch.FloatTensor([123.675, 116.28, 103.53]) self.std = torch.FloatTensor([58.395, 57.12, 57.375]) if isinstance(self.ann_file, list): self.image_infos = self.ann_file else: with open(self.ann_file, "r") as fin: for line in fin: line_split = line.strip().split(",") filename, _, _, label = line_split label = float(label) filename = osp.join(self.data_prefix, filename) self.image_infos.append(dict(filename=filename, label=label)) if cache_in_memory: self.cache = {} for i in tqdm(range(len(self)), desc="Caching fragments"): self.cache[i] = self.__getitem__(i, tocache=True) else: self.cache = None def __getitem__( self, index, fragments=-1, fsize=-1, tocache=False, need_original_frames=False ): if tocache or self.cache is None: if fragments == -1: fragments = self.fragments if fsize == -1: fsize = self.fsize image_info = self.image_infos[index] filename = image_info["filename"] label = image_info["label"] img = torchvision.io.read_image(filename) img_shape = img.shape[1:] image = img.unsqueeze(1) if self.nfrags == 1: ifrag = get_spatial_fragments(image, fragments, fsize) else: ifrag = get_spatial_fragments(image, fragments, fsize) for i in range(1, self.nfrags): ifrag = torch.cat( (ifrag, get_spatial_fragments(image, fragments, fsize)), 1 ) if tocache: return (ifrag, label, img_shape) else: ifrag, label, img_shape = self.cache[index] ifrag = ( ((ifrag.permute(1, 2, 3, 0) - self.mean) / self.std) .squeeze(0) .permute(2, 0, 1) ) data = { "image": ifrag, "gt_label": label, "original_shape": img_shape, } if need_original_frames: data["original_image"] = image.squeeze(1) return data def __len__(self): return len(self.image_infos) class CroppedImageDataset(FragmentImageDataset): def __init__( self, ann_file, data_prefix, size=224, ncrops=1, cache_in_memory=False, phase="test", ): """ Regard Cropping as a special case for Fragments in Grid 1*1. """ super().__init__( ann_file, data_prefix, fragments=1, fsize=224, nfrags=ncrops, cache_in_memory=cache_in_memory, phase=phase, )