|
import numpy as np
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import roma
|
|
from kiui.op import safe_normalize
|
|
|
|
def get_rays(pose, h, w, fovy, opengl=True):
    """Generate one ray per pixel for a pinhole camera.

    Args:
        pose: matrix indexed as ``pose[:3, :3]`` (rotation) and ``pose[:3, 3]``
            (translation). The rotation is applied to camera-space directions
            and the translation is used as the ray origin, i.e. it is treated
            as a camera-to-world transform.
        h: image height in pixels.
        w: image width in pixels.
        fovy: vertical field of view in degrees (the focal length is derived
            from ``h``).
        opengl: when True, the y component is negated and z is set to -1;
            otherwise y is kept as-is and z is +1.

    Returns:
        ``(rays_o, rays_d)``, each viewed as ``(h, w, 3)``. ``rays_d`` is
        passed through ``safe_normalize`` (from ``kiui.op``) before reshaping.
    """
    device = pose.device
    # Sign applied to both the y component and the constant z component.
    axis_sign = -1.0 if opengl else 1.0

    # Pixel column / row index of every pixel, flattened in row-major order.
    cols, rows = torch.meshgrid(
        torch.arange(w, device=device),
        torch.arange(h, device=device),
        indexing="xy",
    )
    cols = cols.flatten()
    rows = rows.flatten()

    center_x = w * 0.5
    center_y = h * 0.5
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # The +0.5 offset shoots each ray through the pixel center.
    dir_x = (cols - center_x + 0.5) / focal
    dir_y = (rows - center_y + 0.5) / focal * axis_sign
    dir_z = torch.full_like(dir_x, axis_sign)
    camera_dirs = torch.stack([dir_x, dir_y, dir_z], dim=-1)  # [h*w, 3]

    # Row-vector form of (rotation @ direction).
    rays_d = camera_dirs @ pose[:3, :3].t()  # [h*w, 3]

    # Every ray starts at the pose translation; expand() keeps it a broadcast view.
    rays_o = pose[:3, 3].expand(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d
|
|
|
|
def orbit_camera_jitter(poses, strength=0.1):
    """Apply a random rotation about the world origin to a batch of poses.

    Two rotation vectors are built per batch element: one along the second
    column of the pose's 3x3 rotation block, scaled by a uniform factor in
    ``[-strength * pi, strength * pi]``, and one along the first column,
    scaled by a uniform factor in ``[-strength * pi / 2, strength * pi / 2]``.
    Their rotation matrices are composed and left-multiplied onto both the
    rotation block and the translation column.

    Args:
        poses: batch of pose matrices, indexed as ``poses[:, :3, :3]`` and
            ``poses[:, :3, 3:]``.
            NOTE(review): presumably [B, 4, 4] orbit cameras with unit-length
            rotation columns -- confirm against callers.
        strength: scale of the random rotation.

    Returns:
        A new tensor shaped like ``poses``; the input is not modified.
    """
    batch_size = poses.shape[0]
    device = poses.device

    # Draw order (second-column axis first, then first-column axis) is kept
    # fixed so results under a given seed do not change.
    scale_first = torch.rand(batch_size, 1, device=device) * 2 - 1
    rotvec_first = poses[:, :3, 1] * strength * np.pi * scale_first

    scale_second = torch.rand(batch_size, 1, device=device) * 2 - 1
    rotvec_second = poses[:, :3, 0] * strength * np.pi / 2 * scale_second

    jitter = roma.rotvec_to_rotmat(rotvec_first) @ roma.rotvec_to_rotmat(rotvec_second)

    # Rotate orientation and position together.
    jittered = poses.clone()
    jittered[:, :3, :3] = jitter @ poses[:, :3, :3]
    jittered[:, :3, 3:] = jitter @ poses[:, :3, 3:]

    return jittered
|
|
|
|
def _distorted_axis_coords(size, num_steps, grid_values, strength):
    """Return ``size`` sampling coordinates in [-1, 1] with randomly jittered knot spacing."""
    # Evenly spaced knots in [0, 1], each jittered by at most half a knot
    # spacing times `strength`, then converted to pixel boundaries.
    knots = torch.linspace(0, 1, num_steps)
    knots = (knots + strength * (torch.rand_like(knots) - 0.5) / (num_steps - 1)).clamp(0, 1)
    knots = (knots * size).long()
    # Pin the ends so the segment lengths sum to exactly `size`.
    knots[0] = 0
    knots[-1] = size
    bounds = knots.tolist()

    # Each pixel segment spans one uniform step of the normalized grid, so
    # unequal segment lengths stretch or squeeze the sampled image locally.
    segments = [
        torch.linspace(lo, hi, stop - start)
        for lo, hi, start, stop in zip(
            grid_values[:-1], grid_values[1:], bounds[:-1], bounds[1:]
        )
    ]
    return torch.cat(segments, dim=0)


def grid_distortion(images, strength=0.5):
    """Randomly warp a batch of images with a piecewise-linear sampling grid.

    A number of grid knots is drawn once per call from ``[8, 16]``. For every
    batch element, independent jittered knot positions are drawn along x and
    y, and the images are resampled with ``F.grid_sample``.

    Args:
        images: tensor unpacked as ``[B, C, H, W]``. Any floating dtype
            supported by ``F.grid_sample`` is accepted; the grid is cast to
            the dtype and device of ``images``.
        strength: jitter magnitude relative to the knot spacing. Values above
            1 can make knots cross, which gives a negative segment length and
            raises inside ``torch.linspace``.

    Returns:
        The distorted images, with the same shape as ``images``.
    """
    batch_size, _, H, W = images.shape

    num_steps = np.random.randint(8, 17)
    grid_values = torch.linspace(-1, 1, num_steps).tolist()

    # Each batch element gets its own random grid; x is drawn before y.
    grids = []
    for _ in range(batch_size):
        xs = _distorted_axis_coords(W, num_steps, grid_values, strength)  # [W]
        ys = _distorted_axis_coords(H, num_steps, grid_values, strength)  # [H]

        grid_x, grid_y = torch.meshgrid(xs, ys, indexing="xy")  # [H, W] each
        grids.append(torch.stack([grid_x, grid_y], dim=-1))  # [H, W, 2]

    # grid_sample requires input and grid to share dtype and device.
    grids = torch.stack(grids, dim=0).to(device=images.device, dtype=images.dtype)

    return F.grid_sample(images, grids, align_corners=False)
|
|
|
|
|