# core/models.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import mcubes
import kiui
from kiui.lpips import LPIPS
from core.lrm_reconstructor import LRM_VSD_Mesh_Net
from core.options import Options
from core.tensoRF import TensorVMSplit_Mesh,TensorVMSplit_NeRF
from torchvision.transforms import v2
from core.geometry.camera.perspective_camera import PerspectiveCamera
from core.geometry.render.neural_render import NeuralRender
from core.geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry
import nvdiffrast.torch as dr
from core.instant_utils.mesh_util import xatlas_uvmap
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#tensorSDF + transformer + volume_rendering
class LTRFM_NeRF(nn.Module):
def __init__(
self,
opt: Options,
):
super().__init__()
self.opt = opt
#predict svd using transformer
self.vsd_net = LRM_VSD_Mesh_Net(
camera_embed_dim=opt.camera_embed_dim,
transformer_dim=opt.transformer_dim,
transformer_layers=opt.transformer_layers,
transformer_heads=opt.transformer_heads,
triplane_low_res=opt.triplane_low_res,
triplane_high_res=opt.triplane_high_res,
triplane_dim=opt.triplane_dim,
encoder_freeze=opt.encoder_freeze,
encoder_type=opt.encoder_type,
encoder_model_name=opt.encoder_model_name,
encoder_feat_dim=opt.encoder_feat_dim,
app_dim=opt.app_dim,
density_dim=opt.density_dim,
app_n_comp=opt.app_n_comp,
density_n_comp=opt.density_n_comp,
)
        aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
        grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device)
        near_far = torch.tensor([opt.znear, opt.zfar]).to(device)

        # tensorf Renderer
        self.tensorRF = TensorVMSplit_NeRF(
            aabb, grid_size,
            density_n_comp=opt.density_n_comp, appearance_n_comp=opt.app_n_comp,
            app_dim=opt.app_dim, density_dim=opt.density_dim, near_far=near_far,
            shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe)
# LPIPS loss
if self.opt.lambda_lpips > 0:
self.lpips_loss = LPIPS(net='vgg')
self.lpips_loss.requires_grad_(False)
def state_dict(self, **kwargs):
# remove lpips_loss
state_dict = super().state_dict(**kwargs)
for k in list(state_dict.keys()):
if 'lpips_loss' in k:
del state_dict[k]
return state_dict
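    # anneal the beta of the TensoRF Laplace density (controls how sharply the SDF is
    # converted to density); presumably called by the training loop on a schedule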
def set_beta(self,t):
self.tensorRF.lap_density.set_beta(t)
# predict svd_volume
def forward_svd_volume(self, images, data):
        # images: [B, 4, 9, H, W]
        # return: dict of TensoRF VM factors (planes, lines, and basis matrices)
        B, V, C, H, W = images.shape

        source_camera = data['source_camera']
        images_vit = data['input_vit']  # input views for the transformer encoder
        source_camera = source_camera.reshape(B, V, -1)  # [B, V, 16]
app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera)
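        # reshape the flattened transformer outputs into the TensoRF VM-decomposition factors:
        # per-axis feature planes [B, 3, n_comp, S, S] and feature lines [B, 3, n_comp, S, 1]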
app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size)
app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1)
density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size)
density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1)
results = {
'app_planes': app_planes,
'app_lines': app_lines,
'basis_mat':basis_mat,
'd_basis_mat':d_basis_mat,
'density_planes':density_planes,
'density_lines':density_lines
}
return results
    def extract_mesh(self,
                     planes: torch.Tensor,
                     mesh_resolution: int = 256,
                     mesh_threshold: float = 0.005,
                     use_texture_map: bool = False,
                     texture_resolution: int = 1024):
device = planes['app_planes'].device
grid_size = mesh_resolution
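        # query the predicted density field on a dense grid_size^3 lattice of points in [-1, 1]^3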
        points = torch.linspace(-1, 1, steps=grid_size).half()
        x, y, z = torch.meshgrid(points, points, points, indexing='ij')
        xyz_samples = torch.stack((x, y, z), dim=0).unsqueeze(0).to(device)
        xyz_samples = xyz_samples.permute(0, 2, 3, 4, 1)
        xyz_samples = xyz_samples.view(1, -1, 1, 3)
grid_out = self.tensorRF.predict_sdf(planes,xyz_samples)
grid_out['sigma']=grid_out['sigma'].view(grid_size,grid_size,grid_size).float()
vertices, faces = mcubes.marching_cubes(
grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(),
mesh_threshold,
)
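        # marching cubes returns vertices in voxel-index coordinates; map them back to [-1, 1]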
vertices = vertices / (mesh_resolution - 1) * 2 - 1
if not use_texture_map:
            # query vertex colors: a single prediction returns shaded RGB, albedo, and shading
            vertices_tensor = torch.tensor(vertices, dtype=torch.float32).to(device).unsqueeze(0)
            color_out = self.tensorRF.predict_color(planes, vertices_tensor)
            rgb_colors = color_out['rgb'].squeeze(0).cpu().numpy()
            rgb_colors = (rgb_colors * 255).astype(np.uint8)
            albedo_colors = color_out['albedo'].squeeze(0).cpu().numpy()
            albedo_colors = (albedo_colors * 255).astype(np.uint8)
            shading_colors = color_out['shading'].squeeze(0).cpu().numpy()
            shading_colors = (shading_colors * 255).astype(np.uint8)
            return vertices, faces, [rgb_colors, albedo_colors, shading_colors]
# use x-atlas to get uv mapping for the mesh
vertices = torch.tensor(vertices, dtype=torch.float32, device=device)
faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device)
ctx = dr.RasterizeCudaContext(device=device)
uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
ctx, vertices, faces, resolution=texture_resolution)
tex_hard_mask = tex_hard_mask.float().cpu()
# query the texture field to get the RGB color for texture map
#TBD here
query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3)
vertices_colors = self.tensorRF.predict_color(
planes, query_vertices)['rgb'].squeeze(0).cpu()
vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3)
background_feature = torch.zeros_like(vertices_colors)
img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask.half())
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
#albedo
vertices_colors_albedo = self.tensorRF.predict_color(
planes, query_vertices)['albedo'].squeeze(0).cpu()
vertices_colors_albedo=vertices_colors_albedo.reshape(1,texture_resolution,texture_resolution,3)
background_feature = torch.zeros_like(vertices_colors_albedo)
img_feat = torch.lerp(background_feature, vertices_colors_albedo, tex_hard_mask.half())
texture_map_albedo = img_feat.permute(0, 3, 1, 2).squeeze(0)
return vertices, faces, uvs, mesh_tex_idx, [texture_map,texture_map_albedo]
def render_frame(self, data):
        # data: output of the dataloader
        # return: dict with rendered images and the predicted SVD volume
        #self.set_beta(data['t'])

        images = data['input_vit']

        # predict the SVD volume (TensoRF factors) from the input views
        svd_volume = self.forward_svd_volume(images, data)

        # always use white bg
        bg_color = torch.ones(3, dtype=torch.float32).to(device)

        # render the supervision views with the TensoRF renderer
        results = self.tensorRF(svd_volume, data['all_rays_o'], data['all_rays_d'],
                                is_train=True, bg_color=bg_color, N_samples=self.opt.n_sample)
        results['svd_volume'] = svd_volume
pred_shading = results['image'] # [B, V, C, output_size, output_size]
pred_alphas = results['alpha'] # [B, V, 1, output_size, output_size]
pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
pred_images = pred_shading*pred_albedos
results['images_pred'] = pred_images
results['alphas_pred'] = pred_alphas
results['pred_albedos'] = pred_albedos
results['pred_shading'] = pred_shading
return results
#tensorSDF + transformer + SDF + Mesh
class LTRFM_Mesh(nn.Module):
def __init__(
self,
opt: Options,
):
super().__init__()
self.opt = opt
# attributes
self.grid_res = 128 #grid_res
self.grid_scale = 2.0 #grid_scale
self.deformation_multiplier = 4.0
self.init_flexicubes_geometry(device, self.opt)
#predict svd using transformer
self.vsd_net = LRM_VSD_Mesh_Net(
camera_embed_dim=opt.camera_embed_dim,
transformer_dim=opt.transformer_dim,
transformer_layers=opt.transformer_layers,
transformer_heads=opt.transformer_heads,
triplane_low_res=opt.triplane_low_res,
triplane_high_res=opt.triplane_high_res,
triplane_dim=opt.triplane_dim,
encoder_freeze=opt.encoder_freeze,
encoder_type=opt.encoder_type,
encoder_model_name=opt.encoder_model_name,
encoder_feat_dim=opt.encoder_feat_dim,
app_dim=opt.app_dim,
density_dim=opt.density_dim,
app_n_comp=opt.app_n_comp,
density_n_comp=opt.density_n_comp,
)
aabb = torch.tensor([[-1, -1, -1], [1, 1, 1]]).to(device)
grid_size = torch.tensor([opt.splat_size, opt.splat_size, opt.splat_size]).to(device)
        near_far = torch.tensor([opt.znear, opt.zfar]).to(device)

        # tensorf Renderer
        self.tensorRF = TensorVMSplit_Mesh(
            aabb, grid_size,
            density_n_comp=opt.density_n_comp, appearance_n_comp=opt.app_n_comp,
            app_dim=opt.app_dim, density_dim=opt.density_dim, near_far=near_far,
            shadingMode=opt.shadingMode, pos_pe=opt.pos_pe, view_pe=opt.view_pe, fea_pe=opt.fea_pe)
# LPIPS loss
if self.opt.lambda_lpips > 0:
self.lpips_loss = LPIPS(net='vgg')
self.lpips_loss.requires_grad_(False)
# load ckpt
if opt.ckpt_nerf is not None:
sd = torch.load(opt.ckpt_nerf, map_location='cpu')['model']
#sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
sd_fc = {}
for k, v in sd.items():
k=k.replace('module.', '')
if k.startswith('vsd.renderModule.'):
continue
else:
sd_fc[k] = v
sd_fc = {k.replace('vsd_net.', ''): v for k, v in sd_fc.items()}
sd_fc = {k.replace('tensorRF.', ''): v for k, v in sd_fc.items()}
# missing `net_deformation` and `net_weight` parameters
self.vsd_net.load_state_dict(sd_fc, strict=False)
self.tensorRF.load_state_dict(sd_fc, strict=False)
print(f'Loaded weights from {opt.ckpt_nerf}')
def state_dict(self, **kwargs):
# remove lpips_loss
state_dict = super().state_dict(**kwargs)
for k in list(state_dict.keys()):
if 'lpips_loss' in k:
del state_dict[k]
return state_dict
# predict svd_volume
def forward_svd_volume(self, images, data):
        # images: [B, 4, 9, H, W]
        # return: dict of TensoRF VM factors (planes, lines, and basis matrices)
        B, V, C, H, W = images.shape

        source_camera = data['source_camera']
        images_vit = data['input_vit']  # input views for the transformer encoder
        source_camera = source_camera.reshape(B, V, -1)  # [B, V, 16]
app_planes,app_lines,basis_mat,d_basis_mat,density_planes,density_lines = self.vsd_net(images_vit,source_camera)
app_planes=app_planes.view(B,3,self.opt.app_n_comp,self.opt.splat_size,self.opt.splat_size)
app_lines=app_lines.view(B,3,self.opt.app_n_comp,self.opt.splat_size,1)
density_planes=density_planes.view(B,3,self.opt.density_n_comp,self.opt.splat_size,self.opt.splat_size)
density_lines=density_lines.view(B,3,self.opt.density_n_comp,self.opt.splat_size,1)
results = {
'app_planes': app_planes,
'app_lines': app_lines,
'basis_mat':basis_mat,
'd_basis_mat':d_basis_mat,
'density_planes':density_planes,
'density_lines':density_lines
}
return results
def init_flexicubes_geometry(self, device, opt):
camera = PerspectiveCamera(opt, device=device)
renderer = NeuralRender(device, camera_model=camera)
self.geometry = FlexiCubesGeometry(
grid_res=self.grid_res,
scale=self.grid_scale,
renderer=renderer,
render_type='neural_render',
device=device,
)
    # query the SVD volume for SDF values, per-vertex deformations, and FlexiCubes weights
def get_sdf_deformation_prediction(self, planes):
'''
Predict SDF and deformation for tetrahedron vertices
:param planes: triplane feature map for the geometry
'''
B = planes['app_lines'].shape[0]
init_position = self.geometry.verts.unsqueeze(0).expand(B, -1, -1)
sdf, deformation, weight = self.tensorRF.get_geometry_prediction(planes,init_position,self.geometry.indices)
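        # bound the predicted per-vertex deformation with tanh so each vertex moves by at most
        # 1 / (grid_res * deformation_multiplier) along each axis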
deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation)
sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32)
sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1))
sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1)
pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1)
neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1)
zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0)
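        # For any batch item whose SDF has no zero crossing (all positive or all negative),
        # inject a synthetic SDF (positive at the grid center, negative at the boundary) so a
        # surface still exists, and penalize that item with an SDF regularization loss below.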
if torch.sum(zero_surface).item() > 0:
update_sdf = torch.zeros_like(sdf[0:1])
max_sdf = sdf.max()
min_sdf = sdf.min()
update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero
update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero
new_sdf = torch.zeros_like(sdf)
for i_batch in range(zero_surface.shape[0]):
if zero_surface[i_batch]:
new_sdf[i_batch:i_batch + 1] += update_sdf
update_mask = (new_sdf == 0).float()
            # Regularization here pushes the SDF toward mixed signs (so it is not fully positive or fully negative)
sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1)
sdf_reg_loss = sdf_reg_loss * zero_surface.float()
sdf = sdf * update_mask + new_sdf * (1 - update_mask)
final_sdf = []
final_def = []
for i_batch in range(zero_surface.shape[0]):
if zero_surface[i_batch]:
final_sdf.append(sdf[i_batch: i_batch + 1].detach())
final_def.append(deformation[i_batch: i_batch + 1].detach())
else:
final_sdf.append(sdf[i_batch: i_batch + 1])
final_def.append(deformation[i_batch: i_batch + 1])
sdf = torch.cat(final_sdf, dim=0)
deformation = torch.cat(final_def, dim=0)
return sdf, deformation, sdf_reg_loss, weight
def get_geometry_prediction(self, planes=None):
'''
        Function to generate a mesh from the given triplanes
:param planes: triplane features
'''
sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes)
v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation
tets = self.geometry.indices
n_batch = planes['app_planes'].shape[0]
v_list = []
f_list = []
flexicubes_surface_reg_list = []
for i_batch in range(n_batch):
verts, faces, flexicubes_surface_reg = self.geometry.get_mesh(
v_deformed[i_batch],
sdf[i_batch].squeeze(dim=-1),
with_uv=False,
indices=tets,
weight_n=weight[i_batch].squeeze(dim=-1),
is_training=self.training,
)
flexicubes_surface_reg_list.append(flexicubes_surface_reg)
v_list.append(verts)
f_list.append(faces)
flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean()
flexicubes_weight_reg = (weight ** 2).mean()
return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg)
def get_texture_prediction(self, planes, tex_pos, hard_mask=None):
'''
Predict Texture given triplanes
:param planes: the triplane feature map
:param tex_pos: Position we want to query the texture field
        :param hard_mask: 2D silhouette of the rendered image
'''
B = planes['app_planes'].shape[0]
tex_pos = torch.cat(tex_pos, dim=0)
        if hard_mask is not None:
tex_pos = tex_pos * hard_mask.float()
batch_size = tex_pos.shape[0]
tex_pos = tex_pos.reshape(batch_size, -1, 3)
###################
# We use mask to get the texture location (to save the memory)
if hard_mask is not None:
n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1)
sample_tex_pose_list = []
max_point = n_point_list.max()
            if max_point == 0:  # xrg: the hard mask may filter out all points, leaving none
                max_point = max_point + 1
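            # gather the masked texture positions of each batch item and zero-pad them to
            # max_point so they can be stacked into a single [B, max_point, 3] tensor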
expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5
for i in range(tex_pos.shape[0]):
tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3)
if tex_pos_one_shape.shape[1] < max_point:
tex_pos_one_shape = torch.cat(
[tex_pos_one_shape, torch.zeros(
1, max_point - tex_pos_one_shape.shape[1], 3,
device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1)
sample_tex_pose_list.append(tex_pos_one_shape)
tex_pos = torch.cat(sample_tex_pose_list, dim=0)
#return texture rgb
tex_feat = self.tensorRF.get_texture_prediction(tex_pos,vsd_vome=planes)
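        # if a mask was used, scatter the per-point features back onto the full image grid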
if hard_mask is not None:
final_tex_feat = torch.zeros(
B, hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device)
expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5
for i in range(B):
final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1)
tex_feat = final_tex_feat
return tex_feat.reshape(B, hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1])
def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256):
'''
Function to render a generated mesh with nvdiffrast
:param mesh_v: List of vertices for the mesh
:param mesh_f: List of faces for the mesh
        :param cam_mv: per-mesh 4x4 camera matrices
:return:
'''
return_value_list = []
for i_mesh in range(len(mesh_v)):
return_value = self.geometry.render_mesh(
mesh_v[i_mesh],
mesh_f[i_mesh].int(),
cam_mv[i_mesh],
resolution=render_size,
hierarchical_mask=False
)
return_value_list.append(return_value)
return_keys = return_value_list[0].keys()
return_value = dict()
for k in return_keys:
value = [v[k] for v in return_value_list]
return_value[k] = value
mask = torch.cat(return_value['mask'], dim=0)
hard_mask = torch.cat(return_value['hard_mask'], dim=0)
tex_pos = return_value['tex_pos']
depth = torch.cat(return_value['depth'], dim=0)
normal = torch.cat(return_value['normal'], dim=0)
return mask, hard_mask, tex_pos, depth, normal
def forward_geometry(self, planes, render_cameras, render_size=256):
'''
        Main function of our Generator. It first generates a 3D mesh, then renders it into 2D images
        with the given `render_cameras`.
:param planes: triplane features
:param render_cameras: cameras to render generated 3D shape, a w2c matrix
'''
B, NV = render_cameras.shape[:2]
# Generate 3D mesh first
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
        # Render the mesh into 2D images and get the 3D position of each pixel on the image plane
cam_mv = render_cameras
run_n_view = cam_mv.shape[1]
antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size)
tex_hard_mask = hard_mask
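        # concatenate all views of each object along the width dimension so the texture field
        # is queried once per object rather than once per view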
tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos]
tex_hard_mask = torch.cat(
[torch.cat(
[tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1]
for i_view in range(run_n_view)], dim=2)
for i in range(B)], dim=0)
# Querying the texture field to predict the texture feature for each pixel on the image
tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask)
background_feature = torch.ones_like(tex_feat) # white background
# Merge them together
img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask)
# We should split it back to the original image shape
img_feat = torch.cat(
[torch.cat(
[img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)]
for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0)
img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV))
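        # the texture field outputs 6 channels per pixel: shaded color (first 3) and albedo (last 3)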
        albedo = img[:, :, 3:6, :, :]
        img = img[:, :, 0:3, :, :]
antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV))
depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive
normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV))
out = {
'image': img,
'albedo': albedo,
'mask': antilias_mask,
'depth': depth,
'normal': normal,
'sdf': sdf,
'mesh_v': mesh_v,
'mesh_f': mesh_f,
'sdf_reg_loss': sdf_reg_loss,
}
return out
def render_frame(self, data):
        # data: output of the dataloader
        # return: dict with rendered images and the predicted SVD volume
        images = data['input_vit']  # [B, 4, 9, H, W], input features

        # predict the SVD volume (TensoRF factors) from the input views
        svd_volume = self.forward_svd_volume(images, data)

        # generate the mesh and render it from the target cameras
        # (the white background is applied inside forward_geometry)
        results = self.forward_geometry(svd_volume, data['w2c'], self.opt.infer_render_size)
        results['svd_volume'] = svd_volume
pred_shading = results['image'] # [B, V, C, output_size, output_size]
pred_alphas = results['mask'] # [B, V, 1, output_size, output_size]
pred_albedos = results['albedo'] # [B, V, C, output_size, output_size]
pred_images=pred_shading*pred_albedos
results['images_pred'] = pred_images
results['alphas_pred'] = pred_alphas
results['pred_albedos'] = pred_albedos
results['pred_shading'] = pred_shading
return results
def extract_mesh(
self,
planes: torch.Tensor,
use_texture_map: bool = False,
texture_resolution: int = 1024,
**kwargs,
):
'''
        Extract a 3D mesh from FlexiCubes. Only supports batch_size 1.
:param planes: triplane features
:param use_texture_map: use texture map or vertex color
        :param texture_resolution: the resolution of the texture map
'''
assert planes['app_planes'].shape[0] == 1
device = planes['app_planes'].device
# predict geometry first
mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes)
vertices, faces = mesh_v[0], mesh_f[0]
if not use_texture_map:
            # query vertex colors: a single prediction returns shaded RGB, albedo, and shading
            vertices_tensor = vertices.unsqueeze(0)
            color_out = self.tensorRF.predict_color(planes, vertices_tensor)
            rgb_colors = color_out['rgb'].clamp(0, 1).squeeze(0).cpu().numpy()
            rgb_colors = (rgb_colors * 255).astype(np.uint8)
            albedo_colors = color_out['albedo'].clamp(0, 1).squeeze(0).cpu().numpy()
            albedo_colors = (albedo_colors * 255).astype(np.uint8)
            shading_colors = color_out['shading'].clamp(0, 1).squeeze(0).cpu().numpy()
            shading_colors = (shading_colors * 255).astype(np.uint8)
            return vertices.cpu().numpy(), faces.cpu().numpy(), [rgb_colors, albedo_colors, shading_colors]
# use x-atlas to get uv mapping for the mesh
        # reuse the renderer's existing rasterization context for UV unwrapping
        uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap(
            self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution)
tex_hard_mask = tex_hard_mask.float().cpu()
# query the texture field to get the RGB color for texture map
#TBD here
query_vertices=gb_pos.view(1,texture_resolution*texture_resolution,3)
vertices_colors = self.tensorRF.predict_color(
planes, query_vertices)['rgb'].squeeze(0).cpu()
vertices_colors=vertices_colors.reshape(1,texture_resolution,texture_resolution,3)
background_feature = torch.zeros_like(vertices_colors)
img_feat = torch.lerp(background_feature, vertices_colors, tex_hard_mask)
texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0)
return vertices, faces, uvs, mesh_tex_idx, texture_map
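

# -----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original pipeline). Assumptions: `Options`
# is default-constructible with the project's hyper-parameters, and a dataloader
# batch provides the keys read by `render_frame`, e.g. 'input_vit', 'source_camera',
# 'w2c', 'all_rays_o', 'all_rays_d'.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    opt = Options()  # assumed default-constructible
    model = LTRFM_Mesh(opt).to(device)
    model.eval()

    with torch.no_grad():
        # `batch` would come from the project's dataloader, e.g.:
        # batch = next(iter(dataloader))
        # out = model.render_frame(batch)
        # verts, faces, colors = model.extract_mesh(out['svd_volume'])
        pass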