Spaces:
Build error
Build error
from transformers import set_seed | |
from tqdm.auto import trange | |
from PIL import Image | |
import numpy as np | |
import random | |
import utils | |
import torch | |
# Configuration schema: a list of (section_title, entries) pairs. Each entry is
# (name, default, spec), where spec is a Python type (str, int, float, bool),
# a (low, high) tuple -- presumably a slider range -- or a list of allowed
# values. The code that consumes this schema is not part of this file.
CONFIG_SPEC = [
    (
        "General",
        [
            ("text", "A cloud at dawn", str),  # text prompt encoded by CLIP
            ("iterations", 5000, (0, 7500)),
            ("seed", 12, int),
            ("show_every", 10, int),
        ],
    ),
    (
        "Rendering",
        [
            ("w", 224, [224, 252]),
            ("h", 224, [224, 252]),
            ("showoff", 5000, (0, 10000)),
            ("turns", 4, int),
            ("focal_length", 0.1, float),
            ("plane_width", 0.1, float),
            ("shade_strength", 0.25, float),
            ("gamma", 0.5, float),
            ("max_depth", 7, float),
            ("offset", 5, float),
            ("offset_random", 0.75, float),
            ("xyz_random", 0.25, float),
            ("altitude_range", 0.3, float),
            ("augments", 3, int),
        ],
    ),
    (
        "Optimization",
        [
            ("epochs", 6, int),
            ("lr", 0.6, float),
            # CLIP loss type; switching it may improve the results.
            ("loss_type", "spherical", ["spherical", "cosine"]),
            # Weight applied to the CLIP loss term.
            ("clip_weight", 1.0, float),
        ],
    ),
    (
        "Elements",
        [
            ("num_objects", 256, int),
            # Dimensionality of each primitive: 0 = point cloud (default),
            # 1 = strokes, 2 = planes, 3 = little cubes.
            ("ndim", 0, [0, 1, 2, 3]),
            # Opacity range; log_opacity interpolates it in log space.
            ("min_opacity", 1e-4, float),
            ("max_opacity", 1.0, float),
            ("log_opacity", False, bool),
            # Radius range; log_radius interpolates it in log space.
            ("min_radius", 0.030, float),
            ("max_radius", 0.150, float),
            ("log_radius", False, bool),
            # Bezier resolution: samples per parametric dimension of a
            # line/plane/cube. Not used for points.
            # TODO: choose bezier_res dynamically.
            ("bezier_res", 8, int),
            # Upper bounds of the position, velocity and acceleration terms.
            ("pos_scale", 0.4, float),
            ("vel_scale", 0.15, float),
            ("acc_scale", 0.15, float),
            # Size of each individual 3D object; master multiplier for the
            # velocity and acceleration terms.
            ("scale", 1, float),
        ],
    ),
]
# TODO: split the config into several parts and break this monolithic class into smaller objects.
# 2022-08-09: roughly halfway done.
class PulsarCLIP(object):
    """Fit Bezier-parameterized primitives, rendered with pytorch3d's pulsar
    renderer, to a CLIP text prompt.

    Each of ``num_objects`` primitives is either a point (``ndim == 0``) or an
    ``ndim``-dimensional Bezier patch sampled at ``bezier_res ** ndim``
    parametric locations. Every sample is rendered by pulsar with its own
    position, colour, radius and opacity. ``generate()`` optimizes these
    parameters so that random views match the prompt's CLIP embedding.

    Args:
        args: mapping of settings, read as attributes (``self.args.<name>``).
            Every name read by this class is listed in ``CONFIG_SPEC``. An
            optional ``"device"`` key overrides the default device (CUDA when
            available, else CPU).
    """

    def __init__(self, args):
        args = DotDict(**args)
        set_seed(args.seed)
        self.args = args
        self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
        # Defer the import so that we can import `pulsar_clip` and then install `pytorch3d`
        import pytorch3d.renderer.points.pulsar as ps

        self.ndim = int(self.args.ndim)
        # Every object expands into bezier_res ** ndim rendered samples.
        n_samples = self.args.num_objects * (self.args.bezier_res ** self.ndim)
        self.renderer = ps.Renderer(self.args.w, self.args.h, n_samples).to(self.device)

        n_obj = args.num_objects
        # Columns: xyz base position, then one radius logit.
        self.bezier_pos = torch.nn.Parameter(torch.randn((n_obj, 4)).to(self.device))
        # One xyz block per parametric dimension for velocity and acceleration.
        self.bezier_vel = torch.nn.Parameter(torch.randn((n_obj, 3 * self.ndim)).to(self.device))
        self.bezier_acc = torch.nn.Parameter(torch.randn((n_obj, 3 * self.ndim)).to(self.device))
        # Columns: 4 base logits (RGB + opacity), then 4 * ndim velocity and
        # 4 * ndim acceleration columns. get_points() reads velocity from
        # [4:4 + 4 * ndim] and acceleration from [-4 * ndim:]. The former width
        # of 4 * (1 + ndim) made those two slices the very same columns.
        self.bezier_col = torch.nn.Parameter(
            torch.randn((n_obj, 4 * (1 + 2 * self.ndim))).to(self.device))
        # Colour, position and shape terms get different fractions of the base lr.
        self.optimizer = torch.optim.Adam([
            dict(params=[self.bezier_col], lr=5e-1 * args.lr),
            dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
            dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
        ])
        self.model_clip, self.preprocess_clip = utils.load_clip()
        self.model_clip.visual.requires_grad_(False)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer, self._restart_period(), eta_min=args.lr / 100)

        import clip
        # The text embedding is a fixed target, so no autograd graph is needed.
        with torch.no_grad():
            tokens = clip.tokenize([self.args.text]).to(self.device)
            txt_emb = self.model_clip.encode_text(tokens)[0].detach()
        self.txt_emb = torch.nn.functional.normalize(txt_emb, dim=-1)

    def _restart_period(self):
        """Return the LR scheduler's restart period T_0 in optimizer steps (>= 1).

        The optimizer steps once every ``augments`` iterations, and the cosine
        schedule restarts so that it completes ``epochs`` cycles over the run.
        CosineAnnealingWarmRestarts rejects T_0 < 1. The raw quotient gives 0
        whenever ``iterations < augments * epochs``, and CONFIG_SPEC allows
        ``iterations == 0``.
        """
        return max(1, int(self.args.iterations / self.args.augments / self.args.epochs))

    def get_points(self):
        """Sample every primitive and return its per-sample render attributes.

        Returns:
            Tuple ``(vert_pos, vert_col, radius, opacity)`` with
            P = ``num_objects * bezier_res ** ndim`` rows:

            - positions, shape (P, 3);
            - RGB colours in (0, 1), shape (P, 3);
            - radii between ``min_radius`` and ``max_radius``, shape (P,);
            - opacities between ``min_opacity`` and ``max_opacity``, shape (P,).
        """
        if self.ndim > 0:
            # Parametric coordinates t in [0, 1] on a bezier_res ** ndim grid,
            # shaped (ndim, num_objects, bezier_res, ..., bezier_res, 1).
            bezier_ts = torch.stack(torch.meshgrid(
                (torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
            ).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)

        def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
            # Evaluate tanh(pos) * pos_scale
            #   + scale * sum_d (tanh(vel_d) * vel_scale * t_d + tanh(acc_d) * acc_scale * t_d ** 2)
            # at every grid sample and flatten to (P, s).
            pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
            vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
            acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
            scale = self.args.scale if scale is None else scale
            if self.ndim == 0:
                # NOTE(review): unlike the ndim > 0 branch below, no tanh is
                # applied here -- confirm this asymmetry is intended.
                return pos * pos_scale
            result = 0.0
            s = pos.shape[-1]
            assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
            # Sequential over dimensions (O(ndim) loop).
            for d, bezier_t in zip(range(self.ndim), bezier_ts):  # TODO replace with fused dimension operation
                result = (result
                          + torch.tanh(vel[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
                          + torch.tanh(acc[..., d * s:(d + 1) * s]).view(
                              (-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
            result = (result * scale
                      + torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
            return result

        vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
        vert_col = interpolate_3D(self.bezier_col[..., :4],
                                  self.bezier_col[..., 4:4 + 4 * self.ndim],
                                  self.bezier_col[..., -4 * self.ndim:])
        # Repeat a per-object value for every sample of that object -> (P, C).
        to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
            (1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
        # Map x in (0, 1) to (a, b), linearly or geometrically (log space).
        rescale = lambda x, a, b, is_log=False: (torch.exp(x
                                                          * np.log(b / a)
                                                          + np.log(a))) if is_log else x * (b - a) + a
        return (
            vert_pos,
            torch.sigmoid(vert_col[..., :3]),
            rescale(
                torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
                self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
            ),
            rescale(torch.sigmoid(vert_col[..., -1]),
                    self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))

    def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
               xyz_random=None, focal_length=None, plane_width=None):
        """Build the 8-element camera parameter tensor passed to the pulsar renderer.

        The camera starts at (0, 0, -offset), is rotated by ``altitude`` about
        axis 0 and then by ``angle`` about axis 1 (via ``utils.rotate_axis``),
        and is optionally jittered. Arguments left as None fall back to the
        matching ``self.args`` value.

        Args:
            angle: orbit angle in radians.
            altitude: elevation rotation, also stored in the output.
            use_random: when False the distance and position jitter are
                multiplied by 0. The random numbers are still drawn, so RNG
                state advances the same way either way.

        Returns:
            float tensor on ``self.device``: camera position (3 values),
            ``altitude, angle, 0``, ``focal_length``, ``plane_width``.
        """
        if offset is None:
            offset = self.args.offset
        if xyz_random is None:
            xyz_random = self.args.xyz_random
        if focal_length is None:
            focal_length = self.args.focal_length
        if plane_width is None:
            plane_width = self.args.plane_width
        if offset_random is None:
            offset_random = self.args.offset_random
        device = self.device
        offset = offset + np.random.normal() * offset_random * int(use_random)
        position = torch.tensor([0, 0, -offset], dtype=torch.float)
        position = utils.rotate_axis(position, altitude, 0)
        position = utils.rotate_axis(position, angle, 1)
        position = position + torch.randn(3) * xyz_random * int(use_random)
        return torch.tensor([position[0], position[1], position[2],
                             altitude, angle, 0,
                             focal_length, plane_width], dtype=torch.float, device=device)

    def render(self, cam_params=None):
        """Render the current primitives and return ``(rgb, alpha)`` images.

        ``alpha`` is a second render with every colour set to zero.
        ``random_view_render`` uses it as the weight of the background.
        Defaults to ``self.camera(0, 0)`` when ``cam_params`` is None.
        """
        if cam_params is None:
            cam_params = self.camera(0, 0)
        vert_pos, vert_col, radius, opacity = self.get_points()
        rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
                            self.args.gamma, self.args.max_depth, opacity=opacity)
        alpha = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
                              self.args.gamma, self.args.max_depth, opacity=opacity)
        return rgb, alpha

    def random_view_render(self):
        """Render from a random orbit angle/altitude and composite a Perlin-noise background.

        The background gets a fresh noise octave frequency (7, 14 or 28) per
        channel. Noise is clipped to [-0.5, 0.5] and shifted into [0, 1].
        """
        angle = random.uniform(0, np.pi * 2)
        altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
        cam_params = self.camera(angle, altitude)
        result, alpha = self.render(cam_params)
        back = torch.zeros_like(result)
        s = back.shape
        for j in range(s[-1]):
            n = random.choice([7, 14, 28])
            back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
        result = result * (1 - alpha) + back * alpha
        return result

    def _clip_loss(self, img_emb):
        """Return the scalar CLIP loss between ``img_emb`` and ``self.txt_emb``.

        Both are expected to be unit-normalized; the training step and
        ``__init__`` normalize them. "spherical" is
        ``2 * arcsin(|a - b| / 2) ** 2``, i.e. half the squared angle between
        unit vectors. "cosine" is ``1 - a . b``.

        Raises:
            NotImplementedError: for any other ``loss_type``.
        """
        loss_type = self.args.loss_type
        if loss_type == "spherical":
            return (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
        if loss_type == "cosine":
            # Row-wise dot product. This avoids `.T` on the 1-D text embedding,
            # which recent PyTorch versions deprecate.
            return (1 - (img_emb * self.txt_emb).sum(dim=-1)).mean()
        raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")

    def _orbit_angle(self, i):
        """Return the preview camera angle (radians) for loop index ``i``.

        Previews orbit ``turns`` times over the ``iterations`` training steps
        and keep turning at that rate during show-off. When ``iterations`` is 0
        (allowed by CONFIG_SPEC), the orbit is spread over the ``showoff``
        steps instead of dividing by zero.
        """
        period = self.args.iterations if self.args.iterations > 0 else max(self.args.showoff, 1)
        return i / period * np.pi * 2 * self.args.turns

    def _train_step(self, i):
        """Accumulate CLIP-loss gradients for one random view.

        The optimizer (and LR scheduler) steps on every ``augments``-th call.
        """
        result = self.random_view_render()
        # NOTE(review): the clamp runs *after* preprocess_clip. If that
        # preprocessing normalizes with CLIP's mean/std (utils.load_clip is not
        # visible here), this clamps normalized values rather than pixels --
        # confirm intent.
        img_emb = self.model_clip.encode_image(
            self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
        img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
        loss = self._clip_loss(img_emb) * self.args.clip_weight  # TODO add more loss types
        loss.backward()
        if i % self.args.augments == self.args.augments - 1:
            self.optimizer.step()
            self.optimizer.zero_grad()
            # Explicit check instead of swallowing AttributeError, which would
            # also hide errors raised inside scheduler.step() itself.
            scheduler = getattr(self, "scheduler", None)
            if scheduler is not None:
                scheduler.step()

    def _preview(self, i):
        """Render the deterministic orbit view for loop index ``i`` as a PIL image."""
        cam_params = self.camera(self._orbit_angle(i), use_random=False)
        # Previews are never backpropagated; skip building an autograd graph.
        with torch.no_grad():
            img_show, _ = self.render(cam_params)
        # Clip before the uint8 cast: out-of-range floats would otherwise overflow.
        pixels = np.clip(img_show.cpu().numpy(), 0.0, 1.0) * 255
        return Image.fromarray(pixels.astype(np.uint8))

    def generate(self):
        """Run the optimization, yielding a PIL preview every ``show_every`` loop steps.

        The loop runs ``iterations + showoff`` steps. The first ``iterations``
        steps train (see ``_train_step``). Every step whose index is a multiple
        of ``show_every`` yields a preview from the orbiting camera. A
        KeyboardInterrupt ends the generator quietly.
        """
        self.optimizer.zero_grad()
        try:
            for i in trange(self.args.iterations + self.args.showoff):
                if i < self.args.iterations:
                    self._train_step(i)
                if i % self.args.show_every == 0:
                    yield self._preview(i)
        except KeyboardInterrupt:
            pass

    def save_obj(self, fn):
        """Export the current ``get_points()`` output to ``fn`` via ``utils.save_obj``."""
        # Export only needs values, so do not record an autograd graph.
        with torch.no_grad():
            points = self.get_points()
        utils.save_obj(points, fn)
class DotDict(dict):
    """A ``dict`` whose keys can also be read as attributes (``d.key == d["key"]``).

    Only reads are redirected. ``d.key = value`` sets an ordinary instance
    attribute, not a dict entry, because no ``__setattr__`` is defined.
    """

    def __getattr__(self, item):
        # __getattr__ runs only after normal attribute lookup has failed.
        # Missing keys must raise AttributeError, not KeyError, so that
        # hasattr(), getattr(obj, name, default) and copy/pickle probing of
        # optional hooks (e.g. __deepcopy__) keep working.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute {item!r}") from None