'''
@author: Zhigang Jiang
@time: 2022/05/23
@description: Gradio demo of LGT-Net for 3D room layout estimation from a single RGB panorama.
'''
import gradio as gr
import numpy as np
import os
import torch
from PIL import Image
from utils.logger import get_logger
from config.defaults import get_config
from inference import preprocess, run_one_inference
from models.build import build_model
from argparse import Namespace
import gdown
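
# Download the pre-trained checkpoint for the requested config from Google Drive
# into ckpt_dir (skipped when best.pkl is already present).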
def down_ckpt(model_cfg, ckpt_dir):
    model_ids = [
        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
    ]

    for model_id in model_ids:
        if model_id[0] != model_cfg:
            continue
        path = os.path.join(ckpt_dir, 'best.pkl')
        if not os.path.exists(path):
            logger.info(f"Downloading {model_id}")
            os.makedirs(ckpt_dir, exist_ok=True)
            gdown.download(f"https://drive.google.com/uc?id={model_id[1]}", path, False)
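
# Gradio callback: optionally pre-process the panorama, run inference with the
# selected pre-trained weights, and return the 2D visualization, the 3D mesh
# (once for the viewer, once as a downloadable file), the vanishing-point file,
# and the layout JSON.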
def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    args.pre_processing = pre_processing
    args.post_processing = post_processing

    if weight_name == 'mp3d':
        model = mp3d_model
    elif weight_name == 'zind':
        model = zind_model
    else:
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError

    img_name = os.path.basename(img_path).split('.')[0]
    img = np.array(Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC))[..., :3]

    vp_cache_path = 'src/demo/default_vp.txt'
    if args.pre_processing:
        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, vp = preprocess(img, vp_cache_path=vp_cache_path)

    img = (img / 255.0).astype(np.float32)
    run_one_inference(img, model, args, img_name,
                      logger=logger, show=False,
                      show_depth='depth-normal-gradient' in visualization,
                      show_floorplan='2d-floorplan' in visualization,
                      mesh_format=mesh_format, mesh_resolution=int(mesh_resolution))

    return [os.path.join(args.output_dir, f"{img_name}_pred.png"),
            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
            os.path.join(args.output_dir, f"{img_name}_3d{mesh_format}"),
            vp_cache_path,
            os.path.join(args.output_dir, f"{img_name}_pred.json")]
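
# Build the model for args.cfg: download its checkpoint if needed and fall back
# to CPU when CUDA is unavailable.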
def get_model(args):
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)
    if ('cuda' in args.device or 'cuda' in config.TRAIN.DEVICE) and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        config.defrost()
        args.device = "cpu"
        config.TRAIN.DEVICE = "cpu"
        config.freeze()
    model, _, _, _ = build_model(config, logger)
    return model
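
# Entry point: load both pre-trained models once at start-up, then build and
# launch the Gradio interface.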
if __name__ == '__main__':
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)

    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)
description = "This demo of the project " \
"<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. " \
"It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    demo = gr.Interface(fn=greet,
                        inputs=[gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
                                gr.Checkbox(label='pre-processing', value=True),
                                gr.Radio(['mp3d', 'zind'],
                                         label='pre-trained weight',
                                         value='mp3d'),
                                gr.Radio(['manhattan', 'atalanta', 'original'],
                                         label='post-processing method',
                                         value='manhattan'),
                                gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
                                                 label='2d-visualization',
                                                 value=['depth-normal-gradient', '2d-floorplan']),
                                gr.Radio(['.gltf', '.obj', '.glb'],
                                         label='output format of 3d mesh',
                                         value='.gltf'),
                                gr.Radio(['128', '256', '512', '1024'],
                                         label='output resolution of 3d mesh',
                                         value='256'),
                                ],
                        outputs=[gr.Image(label='predicted result 2d-visualization', type='filepath'),
                                 gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
                                 gr.File(label='3d mesh file'),
                                 gr.File(label='vanishing point information'),
                                 gr.File(label='layout json')],
                        examples=[
                            ['src/demo/pano_demo1.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/zind_demo1.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/zind_demo2.png', False, 'zind', 'atalanta', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/zind_demo3.png', True, 'zind', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/other_demo1.png', False, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                            ['src/demo/other_demo2.png', True, 'mp3d', 'manhattan', ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256'],
                        ], title='LGT-Net', allow_flagging="never", cache_examples=False, description=description)

    demo.launch(debug=True, enable_queue=False)