# Source: tex3/tsr/system.py
# Last commit: 2ec72fb by hanshu.yan ("add app.py")
import math
import os
from dataclasses import dataclass, field
from typing import List, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import trimesh
from einops import rearrange
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from .models.isosurface import MarchingCubeHelper
from .utils import (
BaseModule,
ImagePreprocessor,
find_class,
get_spherical_cameras,
scale_tensor,
)
class TSR(BaseModule):
    """Image-conditioned 3D reconstruction model.

    Every sub-module (image tokenizer, tokenizer, backbone, post-processor,
    decoder, renderer) is built from a class-path string in the config via
    ``find_class``, so the concrete architecture is entirely config-driven.
    ``forward`` turns input image(s) into scene codes; ``render`` and
    ``extract_mesh`` consume those scene codes.
    """

    @dataclass
    class Config(BaseModule.Config):
        # Size argument passed to ImagePreprocessor for the conditioning image
        # (presumably the resize target in pixels -- confirm in utils).
        cond_image_size: int

        # Each *_cls is a class path resolved with find_class(); the paired
        # dict is passed as the sole constructor argument (see configure()).
        image_tokenizer_cls: str
        image_tokenizer: dict
        tokenizer_cls: str
        tokenizer: dict
        backbone_cls: str
        backbone: dict
        post_processor_cls: str
        post_processor: dict
        decoder_cls: str
        decoder: dict
        renderer_cls: str
        renderer: dict  # renderer.cfg.radius bounds the mesh-extraction grid

    cfg: Config

    @classmethod
    def from_pretrained(
        cls, pretrained_model_name_or_path: str, config_name: str, weight_name: str
    ):
        """Build a model from a config file and a state-dict checkpoint.

        Args:
            pretrained_model_name_or_path: A local directory containing both
                files, or otherwise a Hugging Face Hub repo id to download
                them from.
            config_name: File name of the OmegaConf config.
            weight_name: File name of the checkpoint passed to ``torch.load``.

        Returns:
            A new instance with the checkpoint loaded via ``load_state_dict``
            (checkpoint tensors are mapped to CPU while loading).
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_path = os.path.join(pretrained_model_name_or_path, config_name)
            weight_path = os.path.join(pretrained_model_name_or_path, weight_name)
        else:
            config_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=config_name
            )
            weight_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path, filename=weight_name
            )

        cfg = OmegaConf.load(config_path)
        # Resolve ${...} interpolations in place before the config is consumed.
        OmegaConf.resolve(cfg)
        model = cls(cfg)
        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code if the checkpoint (possibly downloaded from the Hub)
        # is untrusted. Consider weights_only=True once the checkpoint format
        # and the minimum supported torch version are confirmed to allow it.
        ckpt = torch.load(weight_path, map_location="cpu")
        model.load_state_dict(ckpt)
        return model

    def configure(self):
        """Instantiate every sub-module from its configured class path."""
        self.image_tokenizer = find_class(self.cfg.image_tokenizer_cls)(
            self.cfg.image_tokenizer
        )
        self.tokenizer = find_class(self.cfg.tokenizer_cls)(self.cfg.tokenizer)
        self.backbone = find_class(self.cfg.backbone_cls)(self.cfg.backbone)
        self.post_processor = find_class(self.cfg.post_processor_cls)(
            self.cfg.post_processor
        )
        self.decoder = find_class(self.cfg.decoder_cls)(self.cfg.decoder)
        self.renderer = find_class(self.cfg.renderer_cls)(self.cfg.renderer)
        self.image_processor = ImagePreprocessor()
        # Created lazily by set_marching_cubes_resolution(), which
        # extract_mesh() calls before querying the grid.
        self.isosurface_helper = None

    def forward(
        self,
        image: Union[
            PIL.Image.Image,
            np.ndarray,
            torch.FloatTensor,
            List[PIL.Image.Image],
            List[np.ndarray],
            List[torch.FloatTensor],
        ],
        device: str,
    ) -> torch.FloatTensor:
        """Encode conditioning image(s) into scene codes.

        Args:
            image: One image or a list of images (PIL, numpy or tensor);
                conversion and resizing are delegated to ``ImagePreprocessor``.
            device: Device the preprocessed images are moved to.

        Returns:
            The post-processor output. ``render``/``extract_mesh`` iterate over
            it, presumably one scene code per input image along dim 0.
        """
        # Insert a singleton view axis: the model is conditioned on one view.
        rgb_cond = self.image_processor(image, self.cfg.cond_image_size)[:, None].to(
            device
        )
        batch_size = rgb_cond.shape[0]

        # Preprocessed images are channels-last; the tokenizer gets channels-first.
        input_image_tokens: torch.Tensor = self.image_tokenizer(
            rearrange(rgb_cond, "B Nv H W C -> B Nv C H W", Nv=1),
        )
        # Merge the view and token axes into a single sequence: (B, Nv*Nt, C).
        input_image_tokens = rearrange(
            input_image_tokens, "B Nv C Nt -> B (Nv Nt) C", Nv=1
        )

        tokens: torch.Tensor = self.tokenizer(batch_size)
        tokens = self.backbone(
            tokens,
            encoder_hidden_states=input_image_tokens,
        )
        scene_codes = self.post_processor(self.tokenizer.detokenize(tokens))
        return scene_codes

    def render(
        self,
        scene_codes,
        n_views: int,
        elevation_deg: float = 0.0,
        camera_distance: float = 1.9,
        fovy_deg: float = 40.0,
        height: int = 256,
        width: int = 256,
        return_type: str = "pil",
    ):
        """Render every scene code from ``n_views`` cameras.

        Camera rays come from ``get_spherical_cameras`` with the given
        elevation, distance, vertical FOV and image size.

        Args:
            scene_codes: Scene codes (e.g. ``forward`` output); must expose
                ``.device``.
            n_views: Number of camera views rendered per scene code.
            return_type: "pt" (tensor as rendered), "np" (numpy array) or
                "pil" (8-bit ``PIL.Image``; values are clamped to [0, 1]
                before scaling to 0-255).

        Returns:
            A list with one entry per scene code, each a list of ``n_views``
            images in the requested format.

        Raises:
            NotImplementedError: if ``return_type`` is not supported. This is
                checked before any rendering work is done.
        """
        if return_type not in ("pt", "np", "pil"):
            raise NotImplementedError(
                f"return_type must be 'pt', 'np' or 'pil', got {return_type!r}"
            )

        rays_o, rays_d = get_spherical_cameras(
            n_views, elevation_deg, camera_distance, fovy_deg, height, width
        )
        rays_o, rays_d = rays_o.to(scene_codes.device), rays_d.to(scene_codes.device)

        def process_output(image: torch.FloatTensor):
            if return_type == "pt":
                return image
            if return_type == "np":
                return image.detach().cpu().numpy()
            # "pil": clamp first -- casting floats outside [0, 255] to uint8
            # wraps around and shows up as speckle artifacts.
            pixels = image.detach().clamp(0.0, 1.0).cpu().numpy()
            return Image.fromarray((pixels * 255.0).astype(np.uint8))

        images = []
        with torch.no_grad():
            for scene_code in scene_codes:
                views = []
                for i in range(n_views):
                    image = self.renderer(
                        self.decoder, scene_code, rays_o[i], rays_d[i]
                    )
                    views.append(process_output(image))
                images.append(views)
        return images

    def set_marching_cubes_resolution(self, resolution: int):
        """Create the marching-cubes helper unless one at ``resolution`` exists."""
        if (
            self.isosurface_helper is not None
            and self.isosurface_helper.resolution == resolution
        ):
            return
        self.isosurface_helper = MarchingCubeHelper(resolution)

    def extract_mesh(self, scene_codes, resolution: int = 256, threshold: float = 25.0):
        """Extract one vertex-colored triangle mesh per scene code.

        The decoder's ``density_act`` output is sampled on the helper's grid
        rescaled to [-radius, radius] (``renderer.cfg.radius``). The helper
        extracts a surface from the threshold-shifted density, and vertex
        colors are then queried at the extracted vertex positions.

        Args:
            scene_codes: Scene codes (e.g. ``forward`` output); must expose
                ``.device`` when non-empty.
            resolution: Resolution passed to ``MarchingCubeHelper``.
            threshold: Density value subtracted before surface extraction.

        Returns:
            A list of ``trimesh.Trimesh``, one per scene code.
        """
        self.set_marching_cubes_resolution(resolution)
        if len(scene_codes) == 0:
            return []

        helper = self.isosurface_helper
        radius = self.renderer.cfg.radius
        world_range = (-radius, radius)

        # The query grid is the same for every scene: move and rescale it once
        # instead of once per scene code.
        with torch.no_grad():
            grid_points = scale_tensor(
                helper.grid_vertices.to(scene_codes.device),
                helper.points_range,
                world_range,
            )

        meshes = []
        for scene_code in scene_codes:
            with torch.no_grad():
                density = self.renderer.query_triplane(
                    self.decoder,
                    grid_points,
                    scene_code,
                )["density_act"]
            # Shift so the threshold maps to zero; the sign convention follows
            # MarchingCubeHelper (defined elsewhere -- not visible here).
            v_pos, t_pos_idx = helper(-(density - threshold))
            v_pos = scale_tensor(v_pos, helper.points_range, world_range)
            with torch.no_grad():
                color = self.renderer.query_triplane(
                    self.decoder,
                    v_pos,
                    scene_code,
                )["color"]
            meshes.append(
                trimesh.Trimesh(
                    vertices=v_pos.cpu().numpy(),
                    faces=t_pos_idx.cpu().numpy(),
                    vertex_colors=color.cpu().numpy(),
                )
            )
        return meshes