from transformers import set_seed
from tqdm.auto import trange
from PIL import Image
import numpy as np
import random
import utils
import torch

# (name, default, type-or-choices/range) triples grouped by UI section.
# The ``#@markdown`` / ``#@param`` annotations are Colab form markers consumed
# by the notebook front-end that builds a settings form from this spec.
CONFIG_SPEC = [
    ("General", [
        ("text", "A cloud at dawn", str),
        ("iterations", 5000, (0, 7500)),
        ("seed", 12, int),
        ("show_every", 10, int),
    ]),
    ("Rendering", [
        ("w", 224, [224, 252]),
        ("h", 224, [224, 252]),
        ("showoff", 5000, (0, 10000)),
        ("turns", 4, int),
        ("focal_length", 0.1, float),
        ("plane_width", 0.1, float),
        ("shade_strength", 0.25, float),
        ("gamma", 0.5, float),
        ("max_depth", 7, float),
        ("offset", 5, float),
        ("offset_random", 0.75, float),
        ("xyz_random", 0.25, float),
        ("altitude_range", 0.3, float),
        ("augments", 3, int),
    ]),
    ("Optimization", [
        ("epochs", 6, int),
        ("lr", 0.6, float),
        #@markdown CLIP loss type, might improve the results
        ("loss_type", "spherical", ["spherical", "cosine"]),
        #@markdown CLIP loss weight
        ("clip_weight", 1.0, float),  #@param {type: "number"}
    ]),
    ("Elements", [
        ("num_objects", 256, int),
        #@markdown Number of dimensions. 0 is for point clouds (default), 1 will make
        #@markdown strokes, 2 will make planes, 3 produces little cubes
        ("ndim", 0, [0, 1, 2, 3]),  #@param {type: "integer"}
        #@markdown Opacity scale:
        ("min_opacity", 1e-4, float),  #@param {type: "number"}
        ("max_opacity", 1.0, float),  #@param {type: "number"}
        ("log_opacity", False, bool),  #@param {type: "boolean"}
        ("min_radius", 0.030, float),
        ("max_radius", 0.150, float),
        ("log_radius", False, bool),
        # TODO dynamically decide bezier_res
        #@markdown Bezier resolution: how many points a line/plane/cube will have. Not applicable to points
        ("bezier_res", 8, int),  #@param {type: "integer"}
        #@markdown Maximum scale of parameters: position, velocity, acceleration
        ("pos_scale", 0.4, float),  #@param {type: "number"}
        ("vel_scale", 0.15, float),  #@param {type: "number"}
        ("acc_scale", 0.15, float),  #@param {type: "number"}
        #@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
        ("scale", 1, float),  #@param {type: "number"}
    ]),
]


# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
# 2022/08/09: halfway done
class PulsarCLIP(object):
    """Text-to-3D optimizer: fits a differentiable pulsar point/bezier cloud to a
    CLIP text prompt by rendering random views and minimizing a CLIP image-text loss.
    """

    def __init__(self, args):
        """Build renderer, learnable bezier parameters, optimizer/scheduler, and
        the CLIP text embedding for ``args.text``.

        :param args: mapping of config values (see ``CONFIG_SPEC`` for keys/defaults).
        """
        args = DotDict(**args)
        set_seed(args.seed)
        self.args = args
        self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Defer the import so that we can import `pulsar_clip` and then install `pytorch3d`
        import pytorch3d.renderer.points.pulsar as ps
        self.ndim = int(self.args.ndim)
        # Each of the num_objects elements expands into bezier_res**ndim rendered spheres
        # (1 for points, a curve for ndim=1, a surface for ndim=2, a volume for ndim=3).
        self.renderer = ps.Renderer(
            self.args.w, self.args.h,
            self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
        # Learnable parameters: position (xyz + radius logit), per-axis velocity and
        # acceleration of the bezier curves, and color (rgba for position/vel/acc terms).
        self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
        self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
        self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
        # Per-parameter-group learning rates: colors move fastest, geometry slower.
        self.optimizer = torch.optim.Adam([
            dict(params=[self.bezier_col], lr=5e-1 * args.lr),
            dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
            dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
        ])
        self.model_clip, self.preprocess_clip = utils.load_clip()
        self.model_clip.visual.requires_grad_(False)
        # One warm restart per "epoch"; gradients are accumulated over `augments`
        # renders, hence the division by augments as well.
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            int(self.args.iterations / self.args.augments / self.args.epochs),
            eta_min=args.lr / 100)
        import clip
        self.txt_emb = self.model_clip.encode_text(
            clip.tokenize([self.args.text]).to(self.device))[0].detach()
        self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)

    def get_points(self):
        """Evaluate the learnable bezier elements into renderable spheres.

        :return: tuple ``(vert_pos, vert_col, radius, opacity)`` — flattened over all
            objects and bezier samples: positions ``(N, 3)``, RGB colors in [0, 1],
            per-sphere radii in [min_radius, max_radius], opacities in
            [min_opacity, max_opacity].
        """
        if self.ndim > 0:
            # bezier_ts: one [0, 1] parameter grid per curve dimension, broadcast over
            # all objects; iterated below to accumulate velocity/acceleration terms.
            # NOTE(review): torch.meshgrid without `indexing=` relies on the legacy
            # "ij" default and warns on newer torch — confirm torch version before changing.
            bezier_ts = torch.stack(torch.meshgrid(
                (torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
            ).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)

        def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
            # Quadratic bezier-style interpolation: pos + vel*t + acc*t^2 per dimension,
            # with tanh squashing and per-term scales. Scales default to the config values.
            pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
            vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
            acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
            scale = self.args.scale if scale is None else scale
            if self.ndim == 0:
                # NOTE(review): point clouds skip the tanh applied in the ndim>0 path
                # below — presumably intentional (unbounded positions); confirm.
                return pos * pos_scale
            result = 0.0
            s = pos.shape[-1]
            assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # O(dim) sequential lol
            for d, bezier_t in zip(range(self.ndim), bezier_ts):
                # TODO replace with fused dimension operation
                result = (result
                          + torch.tanh(vel[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
                          + torch.tanh(acc[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
            result = (result * scale
                      + torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale
                      ).view(-1, s)
            return result

        vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
        vert_col = interpolate_3D(self.bezier_col[..., :4],
                                  self.bezier_col[..., 4:4 + 4 * self.ndim],
                                  self.bezier_col[..., -4 * self.ndim:])
        # Repeat a per-object value across all of its bezier samples.
        to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
            (1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
        # Map a [0, 1] value into [a, b], optionally on a log scale.
        rescale = lambda x, a, b, is_log=False: (
            torch.exp(x * np.log(b / a) + np.log(a))) if is_log else x * (b - a) + a
        return (
            vert_pos,
            torch.sigmoid(vert_col[..., :3]),
            rescale(torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
                    self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius),
            rescale(torch.sigmoid(vert_col[..., -1]),
                    self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))

    def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
               xyz_random=None, focal_length=None, plane_width=None):
        """Build an 8-float pulsar camera parameter vector orbiting the origin.

        :param angle: azimuth in radians around the vertical axis.
        :param altitude: elevation in radians.
        :param offset: camera distance from the origin (default: ``args.offset``).
        :param use_random: if True, jitter distance and position for augmentation.
        :param offset_random: stddev of the distance jitter (default from args).
        :param xyz_random: stddev of the positional jitter (default from args).
        :param focal_length: camera focal length (default from args).
        :param plane_width: sensor plane width (default from args).
        :return: tensor ``[x, y, z, altitude, angle, 0, focal_length, plane_width]``.
        """
        if offset is None:
            offset = self.args.offset
        if xyz_random is None:
            xyz_random = self.args.xyz_random
        if focal_length is None:
            focal_length = self.args.focal_length
        if plane_width is None:
            plane_width = self.args.plane_width
        if offset_random is None:
            offset_random = self.args.offset_random
        device = self.device
        # int(use_random) zeroes the jitter when deterministic views are requested.
        offset = offset + np.random.normal() * offset_random * int(use_random)
        position = torch.tensor([0, 0, -offset], dtype=torch.float)
        position = utils.rotate_axis(position, altitude, 0)
        position = utils.rotate_axis(position, angle, 1)
        position = position + torch.randn(3) * xyz_random * int(use_random)
        return torch.tensor([position[0], position[1], position[2],
                             altitude, angle, 0,
                             focal_length, plane_width], dtype=torch.float, device=device)

    def render(self, cam_params=None):
        """Render the current elements from ``cam_params`` (default: front view).

        :return: ``(rgb, opacity)`` image tensors; the second render uses zeroed
            colors so it acts as an alpha/background mask for compositing.
        """
        if cam_params is None:
            cam_params = self.camera(0, 0)
        vert_pos, vert_col, radius, opacity = self.get_points()
        rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
                            self.args.gamma, self.args.max_depth, opacity=opacity)
        opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
                                self.args.gamma, self.args.max_depth, opacity=opacity)
        return rgb, opacity

    def random_view_render(self):
        """Render one randomly-augmented view composited over random Perlin noise.

        :return: HxWxC image tensor used as CLIP input during optimization.
        """
        angle = random.uniform(0, np.pi * 2)
        altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
        cam_params = self.camera(angle, altitude)
        result, alpha = self.render(cam_params)
        # Per-channel Perlin background with a random octave base frequency; keeps
        # CLIP from latching onto a constant background color.
        back = torch.zeros_like(result)
        s = back.shape
        for j in range(s[-1]):
            n = random.choice([7, 14, 28])
            back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
        result = result * (1 - alpha) + back * alpha
        return result

    def generate(self):
        """Run the optimization loop, yielding preview PIL images every
        ``show_every`` steps. After ``iterations`` steps, continues for ``showoff``
        steps that only spin the camera (no training). Ctrl-C stops cleanly.
        """
        self.optimizer.zero_grad()
        try:
            for i in trange(self.args.iterations + self.args.showoff):
                if i < self.args.iterations:
                    result = self.random_view_render()
                    img_emb = self.model_clip.encode_image(
                        self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
                    img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
                    if self.args.loss_type == "spherical":
                        # Squared great-circle distance between unit embeddings.
                        clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
                    elif self.args.loss_type == "cosine":
                        clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
                    else:
                        raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
                    # NOTE(review): original read `... + (0 and ...)`, which always
                    # evaluates to 0 — dropped as a dead placeholder.
                    loss = clip_loss * self.args.clip_weight  # TODO add more loss types
                    loss.backward()
                    # Accumulate gradients over `augments` views before stepping.
                    if i % self.args.augments == self.args.augments - 1:
                        self.optimizer.step()
                        self.optimizer.zero_grad()
                        try:
                            self.scheduler.step()
                        except AttributeError:
                            pass
                if i % self.args.show_every == 0:
                    cam_params = self.camera(
                        i / self.args.iterations * np.pi * 2 * self.args.turns,
                        use_random=False)
                    img_show, _ = self.render(cam_params)
                    img = Image.fromarray((img_show.cpu().detach().numpy() * 255).astype(np.uint8))
                    yield img
        except KeyboardInterrupt:
            pass

    def save_obj(self, fn):
        """Export the current elements to a Wavefront .obj file at path ``fn``."""
        utils.save_obj(self.get_points(), fn)


class DotDict(dict):
    """Dict with attribute-style read access (``d.key`` == ``d["key"]``)."""

    def __getattr__(self, item):
        # Re-raise as AttributeError so hasattr()/getattr(..., default) behave
        # correctly (the original raised KeyError, violating the protocol).
        try:
            return self.__getitem__(item)
        except KeyError as e:
            raise AttributeError(item) from e