import math import numpy as np import skvideo.io from PIL import Image # Shifts src_tf dim to dest dim # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): n_dims = len(x.shape) if src_dim < 0: src_dim = n_dims + src_dim if dest_dim < 0: dest_dim = n_dims + dest_dim assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims dims = list(range(n_dims)) del dims[src_dim] permutation = [] ctr = 0 for i in range(n_dims): if i == dest_dim: permutation.append(src_dim) else: permutation.append(dims[ctr]) ctr += 1 x = x.permute(permutation) if make_contiguous: x = x.contiguous() return x # reshapes tensor start from dim i (inclusive) # to dim j (exclusive) to the desired shape # e.g. if x.shape = (b, thw, c) then # view_range(x, 1, 2, (t, h, w)) returns # x of shape (b, t, h, w, c) def view_range(x, i, j, shape): shape = tuple(shape) n_dims = len(x.shape) if i < 0: i = n_dims + i if j is None: j = n_dims elif j < 0: j = n_dims + j assert 0 <= i < j <= n_dims x_shape = x.shape target_shape = x_shape[:i] + shape + x_shape[j:] return x.view(target_shape) def tensor_slice(x, begin, size): assert all([b >= 0 for b in begin]) size = [l - b if s == -1 else s for s, b, l in zip(size, begin, x.shape)] assert all([s >= 0 for s in size]) slices = [slice(b, b + s) for b, s in zip(begin, size)] return x[slices] def save_video_grid(video, fname, nrow=None, fps=5): b, c, t, h, w = video.shape video = video.permute(0, 2, 3, 4, 1) video = (video.cpu().numpy() * 255).astype('uint8') if nrow is None: nrow = math.ceil(math.sqrt(b)) ncol = math.ceil(b / nrow) padding = 1 video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype='uint8') for i in range(b): r = i // ncol c = i % ncol start_r = (padding + h) * r start_c = (padding + w) * c video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)}) def save_gif_grid(video, file_name, nrow=None, fps=5): b, c, t, h, w = video.shape video = video.permute(0, 2, 3, 4, 1) video = (video.cpu().numpy() * 255).astype('uint8') if nrow is None: nrow = math.ceil(math.sqrt(b)) ncol = math.ceil(b / nrow) padding = 1 video_grid = np.zeros((t, (padding + h) * nrow + padding, (padding + w) * ncol + padding, c), dtype='uint8') for i in range(b): r = i // ncol c = i % ncol start_r = (padding + h) * r start_c = (padding + w) * c video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] images = [] for frame in video_grid: images.append(Image.fromarray(frame)) # Save the first image and append the rest of the images as frames in the GIF images[0].save(file_name, save_all=True, append_images=images[1:], optimize=False, duration=int(1000/fps), loop=0) # The 'duration' parameter defines the display time for each frame in milliseconds # The 'loop' parameter defines the number of loops the GIF should make (0 for infinite loop)