Spaces:
Sleeping
Sleeping
import numpy as np | |
import torch | |
import xatlas | |
import trimesh | |
import moderngl | |
from PIL import Image | |
def make_atlas(mesh, texture_resolution, texture_padding):
    """Unwrap *mesh* into a single UV atlas with xatlas.

    Args:
        mesh: Object exposing ``vertices`` and ``faces`` (passed straight to
            ``xatlas.Atlas.add_mesh``).
        texture_resolution: Value assigned to the packer's ``resolution``.
        texture_padding: Value assigned to the packer's ``padding``.

    Returns:
        A dict with the three arrays xatlas yields for the first (only) mesh:
        ``"vmapping"``, ``"indices"`` and ``"uvs"``.
    """
    # Configure packing first; it is independent of the mesh being added.
    pack_options = xatlas.PackOptions()
    pack_options.resolution = texture_resolution
    pack_options.padding = texture_padding
    pack_options.bilinear = True

    packer = xatlas.Atlas()
    packer.add_mesh(mesh.vertices, mesh.faces)
    packer.generate(pack_options=pack_options)

    # Only one mesh was added, so its result lives at index 0.
    vmapping, indices, uvs = packer[0]
    return dict(vmapping=vmapping, indices=indices, uvs=uvs)
def rasterize_position_atlas(
    mesh, atlas_vmapping, atlas_indices, atlas_uvs, texture_resolution, texture_padding
):
    """Rasterize the mesh's 3D vertex positions into its UV atlas.

    Every triangle is drawn in UV space (UVs in [0, 1] are mapped to clip
    space) with the interpolated 3D position written to RGB and 1.0 written
    to alpha. Texels never touched keep the clear value ``(0, 0, 0, 0)``.

    Two passes are rendered into the same float framebuffer:

    1. A geometry-shader pass that expands each triangle edge into a quad
       offset perpendicular to the edge by ``u_dilation / u_resolution`` in
       clip space, so positions bleed past the chart borders.
    2. The plain triangles, drawn afterwards so they overwrite the dilation
       band wherever real geometry exists.

    Args:
        mesh: Object whose ``vertices`` can be indexed by ``atlas_vmapping``.
        atlas_vmapping: Index array selecting the mesh vertex for each atlas
            vertex.
        atlas_indices: Triangle index array into the atlas vertices.
        atlas_uvs: Per-atlas-vertex UV coordinates.
        texture_resolution: Width and height of the square output texture.
        texture_padding: Dilation amount fed to the geometry shader.

    Returns:
        A ``float32`` NumPy array of shape
        ``(texture_resolution, texture_resolution, 4)`` built over the bytes
        read back from the framebuffer.
    """
    basic_vertex_src = """
        #version 330
        in vec2 in_uv;
        in vec3 in_pos;
        out vec3 v_pos;
        void main() {
            v_pos = in_pos;
            gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
        }
    """
    basic_fragment_src = """
        #version 330
        in vec3 v_pos;
        out vec4 o_col;
        void main() {
            o_col = vec4(v_pos, 1.0);
        }
    """
    gs_vertex_src = """
        #version 330
        in vec2 in_uv;
        in vec3 in_pos;
        out vec3 vg_pos;
        void main() {
            vg_pos = in_pos;
            gl_Position = vec4(in_uv * 2.0 - 1.0, 0.0, 1.0);
        }
    """
    gs_geometry_src = """
        #version 330
        uniform float u_resolution;
        uniform float u_dilation;
        layout (triangles) in;
        layout (triangle_strip, max_vertices = 12) out;
        in vec3 vg_pos[];
        out vec3 vf_pos;
        void lineSegment(int aidx, int bidx) {
            vec2 a = gl_in[aidx].gl_Position.xy;
            vec2 b = gl_in[bidx].gl_Position.xy;
            vec3 aCol = vg_pos[aidx];
            vec3 bCol = vg_pos[bidx];
            vec2 dir = normalize((b - a) * u_resolution);
            vec2 offset = vec2(-dir.y, dir.x) * u_dilation / u_resolution;
            gl_Position = vec4(a + offset, 0.0, 1.0);
            vf_pos = aCol;
            EmitVertex();
            gl_Position = vec4(a - offset, 0.0, 1.0);
            vf_pos = aCol;
            EmitVertex();
            gl_Position = vec4(b + offset, 0.0, 1.0);
            vf_pos = bCol;
            EmitVertex();
            gl_Position = vec4(b - offset, 0.0, 1.0);
            vf_pos = bCol;
            EmitVertex();
        }
        void main() {
            lineSegment(0, 1);
            lineSegment(1, 2);
            lineSegment(2, 0);
            EndPrimitive();
        }
    """
    gs_fragment_src = """
        #version 330
        in vec3 vf_pos;
        out vec4 o_col;
        void main() {
            o_col = vec4(vf_pos, 1.0);
        }
    """

    ctx = moderngl.create_context(standalone=True)
    # Every GL object created below is tracked so it can be released even if
    # a later step raises; otherwise each call would leak a whole context.
    owned = []

    def keep(gl_object):
        owned.append(gl_object)
        return gl_object

    try:
        basic_prog = keep(
            ctx.program(
                vertex_shader=basic_vertex_src,
                fragment_shader=basic_fragment_src,
            )
        )
        gs_prog = keep(
            ctx.program(
                vertex_shader=gs_vertex_src,
                geometry_shader=gs_geometry_src,
                fragment_shader=gs_fragment_src,
            )
        )

        uvs = atlas_uvs.flatten().astype("f4")
        pos = mesh.vertices[atlas_vmapping].flatten().astype("f4")
        indices = atlas_indices.flatten().astype("i4")
        vbo_uvs = keep(ctx.buffer(uvs))
        vbo_pos = keep(ctx.buffer(pos))
        ibo = keep(ctx.buffer(indices))
        vao_content = [
            vbo_uvs.bind("in_uv", layout="2f"),
            vbo_pos.bind("in_pos", layout="3f"),
        ]
        basic_vao = keep(ctx.vertex_array(basic_prog, vao_content, ibo))
        gs_vao = keep(ctx.vertex_array(gs_prog, vao_content, ibo))

        color_texture = keep(
            ctx.texture((texture_resolution, texture_resolution), 4, dtype="f4")
        )
        fbo = keep(ctx.framebuffer(color_attachments=[color_texture]))
        fbo.use()
        fbo.clear(0.0, 0.0, 0.0, 0.0)

        gs_prog["u_resolution"].value = texture_resolution
        gs_prog["u_dilation"].value = texture_padding
        # Dilated edges first, real triangles second so they win on overlap.
        gs_vao.render()
        basic_vao.render()

        # read() copies the texels into a Python bytes object, so the data
        # stays valid after the GL objects are released below.
        fbo_bytes = color_texture.read()
    finally:
        # Release dependants before the objects they reference.
        for gl_object in reversed(owned):
            gl_object.release()
        ctx.release()

    fbo_np = np.frombuffer(fbo_bytes, dtype="f4").reshape(
        texture_resolution, texture_resolution, 4
    )
    return fbo_np
def positions_to_colors(model, scene_code, positions_texture, texture_resolution):
    """Turn a position atlas into an RGBA colour atlas by querying the model.

    Args:
        model: Object exposing ``renderer.query_triplane`` and ``decoder``.
        scene_code: Passed through unchanged to ``query_triplane``. When it
            is a ``torch.Tensor`` the query positions are moved to its device.
        positions_texture: Array reshaped here to ``(-1, 4)``; the first three
            channels are used as query positions and the last as alpha.
        texture_resolution: Side length of the square output texture.

    Returns:
        A NumPy array of shape ``(texture_resolution, texture_resolution, 4)``
        holding the queried colour in RGB and the input alpha in A. Texels
        whose alpha is exactly 0 are zeroed in all four channels.
    """
    texels = positions_texture.reshape(-1, 4)
    positions = torch.tensor(texels[:, :-1])
    # Keep the query on the same device as the scene code; a CPU tensor here
    # would otherwise clash with a scene code living on an accelerator.
    if isinstance(scene_code, torch.Tensor):
        positions = positions.to(scene_code.device)

    with torch.no_grad():
        queried_grid = model.renderer.query_triplane(
            model.decoder,
            positions,
            scene_code,
        )

    # .numpy() rejects accelerator tensors and tensors that require grad, so
    # detach and bring the result to host memory first (no-ops on plain CPU
    # tensors).
    rgb_f = queried_grid["color"].detach().cpu().numpy().reshape(-1, 3)
    rgba_f = np.insert(rgb_f, 3, texels[:, -1], axis=1)
    # Uncovered texels still got a colour from the model; blank them out.
    rgba_f[rgba_f[:, -1] == 0.0] = [0, 0, 0, 0]
    return rgba_f.reshape(texture_resolution, texture_resolution, 4)
def bake_texture(mesh, model, scene_code, texture_resolution):
    """Bake a colour texture for *mesh* by sampling *model* over a UV atlas.

    The mesh is unwrapped with ``make_atlas``, its positions are rasterized
    into the atlas with ``rasterize_position_atlas``, and those positions are
    converted to colours with ``positions_to_colors``.

    Args:
        mesh: Mesh handed to the atlas and rasterization steps.
        model: Model handed to the colour-query step.
        scene_code: Scene code handed to the colour-query step.
        texture_resolution: Side length of the square texture.

    Returns:
        A dict with the atlas arrays under ``"vmapping"``, ``"indices"`` and
        ``"uvs"``, plus the baked texture under ``"colors"``.
    """
    # Padding scales with resolution but never drops below two texels.
    texture_padding = round(max(2, texture_resolution / 256))

    atlas = make_atlas(mesh, texture_resolution, texture_padding)
    vmapping = atlas["vmapping"]
    indices = atlas["indices"]
    uvs = atlas["uvs"]

    positions_texture = rasterize_position_atlas(
        mesh, vmapping, indices, uvs, texture_resolution, texture_padding
    )
    colors_texture = positions_to_colors(
        model, scene_code, positions_texture, texture_resolution
    )

    baked = {key: atlas[key] for key in ("vmapping", "indices", "uvs")}
    baked["colors"] = colors_texture
    return baked