Spaces:
Running
on
Zero
Running
on
Zero
import math | |
import numpy as np | |
import skvideo.io | |
from PIL import Image | |
def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
    """Move dimension ``src_dim`` of ``x`` so it sits at index ``dest_dim``.

    All other dimensions keep their relative order, e.g.
    ``shift_dim(x, 1, -1)`` maps ``(b, c, t, h, w)`` to ``(b, t, h, w, c)``.
    Negative indices count from the end; with the defaults the dimension
    order is unchanged.

    Args:
        x: tensor exposing ``.shape`` and ``.permute`` (e.g. a torch tensor).
        src_dim: dimension to move.
        dest_dim: position that dimension occupies in the result.
        make_contiguous: when true, ``.contiguous()`` is called on the result.

    Returns:
        The permuted tensor.
    """
    rank = len(x.shape)
    src = src_dim + rank if src_dim < 0 else src_dim
    dest = dest_dim + rank if dest_dim < 0 else dest_dim
    assert 0 <= src < rank and 0 <= dest < rank
    # Every dimension except the moved one, in original order; the moved
    # dimension is then spliced in at `dest`.
    others = [d for d in range(rank) if d != src]
    order = others[:dest] + [src] + others[dest:]
    moved = x.permute(order)
    return moved.contiguous() if make_contiguous else moved
def view_range(x, i, j, shape):
    """Reshape the dimension span ``[i, j)`` of ``x`` into ``shape``.

    Example: for ``x`` of shape ``(b, thw, c)``,
    ``view_range(x, 1, 2, (t, h, w))`` has shape ``(b, t, h, w, c)``.

    Args:
        x: tensor exposing ``.shape`` and ``.view`` (e.g. a torch tensor).
        i: first dimension of the span (inclusive); negative counts from end.
        j: end of the span (exclusive); ``None`` means "through the last
            dimension", negative counts from the end.
        shape: iterable of sizes that replaces ``x.shape[i:j]``.

    Returns:
        ``x.view(...)`` with the new shape. Because this is a view, ``x`` must
        be view-compatible for the requested shape (torch ``Tensor.view``
        rules apply).
    """
    middle = tuple(shape)
    rank = len(x.shape)
    start = i + rank if i < 0 else i
    if j is None:
        stop = rank
    elif j < 0:
        stop = j + rank
    else:
        stop = j
    assert 0 <= start < stop <= rank
    dims = tuple(x.shape)
    return x.view(dims[:start] + middle + dims[stop:])
def tensor_slice(x, begin, size):
    """Return the block of ``x`` that starts at ``begin`` and spans ``size``.

    Args:
        x: array/tensor exposing ``.shape`` and multi-axis slicing
            (works for NumPy arrays and torch tensors).
        begin: non-negative start index per leading dimension.
        size: extent per leading dimension; ``-1`` means "to the end of that
            dimension". Only as many dimensions as ``begin``/``size`` cover
            are sliced (``zip`` truncates to the shortest of the three).

    Returns:
        ``x[b0:b0+s0, b1:b1+s1, ...]``.
    """
    assert all(b >= 0 for b in begin)
    size = [dim - b if s == -1 else s
            for s, b, dim in zip(size, begin, x.shape)]
    assert all(s >= 0 for s in size)
    # Must be a tuple: NumPy >= 1.23 no longer treats a *list* of slices as a
    # multi-axis index (it raises IndexError); a tuple works for torch as well.
    slices = tuple(slice(b, b + s) for b, s in zip(begin, size))
    return x[slices]
def make_video_grid(video, nrow=None, padding=1):
    """Tile a batch of videos into one ``uint8`` grid video.

    Args:
        video: batch shaped ``(b, c, t, h, w)``; a torch tensor or anything
            ``np.asarray`` accepts. Values are multiplied by 255, so they are
            presumably in ``[0, 1]``; results outside ``[0, 255]`` are clipped.
        nrow: number of grid rows; defaults to ``ceil(sqrt(b))``.
        padding: width in pixels of the black border around every tile.

    Returns:
        ``uint8`` array of shape ``(t, (padding + h) * nrow + padding,
        (padding + w) * ncol + padding, c)`` where ``ncol = ceil(b / nrow)``.
    """
    if hasattr(video, 'detach'):
        # torch tensor: drop autograd history and move to host before numpy()
        video = video.detach().cpu().numpy()
    video = np.asarray(video)
    b, c, t, h, w = video.shape
    # (b, c, t, h, w) -> (b, t, h, w, c): channels-last frames
    video = np.transpose(video, (0, 2, 3, 4, 1))
    # Clip before casting; out-of-range floats would otherwise wrap around.
    video = np.clip(video * 255, 0, 255).astype('uint8')
    if nrow is None:
        nrow = math.ceil(math.sqrt(b))
    ncol = math.ceil(b / nrow)
    grid = np.zeros((t, (padding + h) * nrow + padding,
                     (padding + w) * ncol + padding, c), dtype='uint8')
    for i in range(b):
        row, col = divmod(i, ncol)
        # Leading `padding` offset gives every tile a border on all four sides
        # (the grid allocates padding * (nrow + 1) rows of border).
        top = padding + (padding + h) * row
        left = padding + (padding + w) * col
        grid[:, top:top + h, left:left + w] = video[i]
    return grid


def save_video_grid(video, fname, nrow=None, fps=5, padding=1):
    """Write a batch of videos to ``fname`` as one tiled video.

    See ``make_video_grid`` for the ``video``, ``nrow`` and ``padding``
    arguments. The file is written with ``skvideo.io.vwrite``.

    Args:
        fname: output path.
        fps: frame rate, passed to skvideo as the ffmpeg input option ``-r``.
    """
    video_grid = make_video_grid(video, nrow=nrow, padding=padding)
    skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)})


def save_gif_grid(video, file_name, nrow=None, fps=5, padding=1):
    """Write a batch of videos to ``file_name`` as one tiled, looping GIF.

    See ``make_video_grid`` for the ``video``, ``nrow`` and ``padding``
    arguments.

    Args:
        file_name: output path.
        fps: playback rate; each frame is shown for ``int(1000 / fps)`` ms.
    """
    video_grid = make_video_grid(video, nrow=nrow, padding=padding)
    if video_grid.shape[-1] == 1:
        # PIL cannot build an image from an (h, w, 1) array; drop the channel
        # axis so single-channel frames become grayscale images.
        video_grid = video_grid[..., 0]
    images = [Image.fromarray(frame) for frame in video_grid]
    # Save the first image and append the rest as GIF frames. `duration` is
    # the per-frame display time in milliseconds; `loop=0` repeats forever.
    images[0].save(file_name, save_all=True, append_images=images[1:],
                   optimize=False, duration=int(1000 / fps), loop=0)