File size: 6,985 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
'''
@author: Zhigang Jiang
@time: 2022/05/23
@description: Gradio web demo for LGT-Net; predicts the 3d room layout of an rgb panorama.
'''

import gradio as gr
import numpy as np
import os
import torch

from PIL import Image

from utils.logger import get_logger
from config.defaults import get_config
from inference import preprocess, run_one_inference
from models.build import build_model
from argparse import Namespace
import gdown


def down_ckpt(model_cfg, ckpt_dir):
    """Download the pre-trained checkpoint registered for ``model_cfg``.

    The config path is compared verbatim against a fixed table of Google
    Drive file ids. When it is listed and ``<ckpt_dir>/best.pkl`` does not
    exist yet, ``ckpt_dir`` is created and the file is fetched with
    ``gdown``. Config paths that are not in the table are ignored, and an
    existing ``best.pkl`` is never overwritten.

    Progress is reported through the module-level ``logger``.

    :param model_cfg: path of the model config file (e.g. 'src/config/mp3d.yaml').
    :param ckpt_dir: directory that holds, or will hold, ``best.pkl``.
    """
    # [config path, Google Drive file id] pairs of the published checkpoints.
    known_checkpoints = [
        ['src/config/mp3d.yaml', '1o97oAmd-yEP5bQrM0eAWFPLq27FjUDbh'],
        ['src/config/zind.yaml', '1PzBj-dfDfH_vevgSkRe5kczW0GVl_43I'],
        ['src/config/pano.yaml', '1JoeqcPbm_XBPOi6O9GjjWi3_rtyPZS8m'],
        ['src/config/s2d3d.yaml', '1PfJzcxzUsbwwMal7yTkBClIFgn8IdEzI'],
        ['src/config/ablation_study/full.yaml', '1U16TxUkvZlRwJNaJnq9nAUap-BhCVIha']
    ]

    requested = [entry for entry in known_checkpoints if entry[0] == model_cfg]
    for entry in requested:
        ckpt_file = os.path.join(ckpt_dir, 'best.pkl')
        if os.path.exists(ckpt_file):
            # Already downloaded earlier; nothing to do for this entry.
            continue
        logger.info(f"Downloading {entry}")
        os.makedirs(ckpt_dir, exist_ok=True)
        drive_url = f"https://drive.google.com/uc?id={entry[1]}"
        gdown.download(drive_url, ckpt_file, False)


def greet(img_path, pre_processing, weight_name, post_processing, visualization, mesh_format, mesh_resolution):
    """Gradio callback: run one layout inference and return the result file paths.

    Stores the chosen pre-/post-processing options on the module-level
    ``args``, picks the module-level model matching ``weight_name``, loads and
    resizes the image, optionally runs ``preprocess`` and finally calls
    ``run_one_inference``.

    :param img_path: path of the input panorama image.
    :param pre_processing: when truthy, ``preprocess`` is applied to the image.
    :param weight_name: 'mp3d' or 'zind'; selects the pre-trained model.
    :param post_processing: value stored in ``args.post_processing``.
    :param visualization: container of selected 2d views; checked for
        'depth-normal-gradient' and '2d-floorplan'.
    :param mesh_format: mesh file extension, appended as-is to the mesh file name.
    :param mesh_resolution: mesh resolution; converted with ``int()``.
    :return: list of five paths - prediction image, mesh (twice, one entry per
        output widget), vanishing-point file and layout json.
    :raises NotImplementedError: if ``weight_name`` is neither 'mp3d' nor 'zind'.
    """
    args.pre_processing = pre_processing
    args.post_processing = post_processing

    # Reject unsupported weights before touching the image.
    if weight_name not in ('mp3d', 'zind'):
        logger.error("unknown pre-trained weight name")
        raise NotImplementedError
    model = mp3d_model if weight_name == 'mp3d' else zind_model

    # Base name is everything before the FIRST dot of the file name.
    file_name = os.path.basename(img_path)
    img_name = file_name.partition('.')[0]

    resized = Image.open(img_path).resize((1024, 512), Image.Resampling.BICUBIC)
    # Keep at most the first three entries along the last axis.
    img = np.array(resized)[..., :3]

    if args.pre_processing:
        vp_cache_path = os.path.join('src/output', f'{img_name}_vp.txt')
        logger.info("pre-processing ...")
        img, _ = preprocess(img, vp_cache_path=vp_cache_path)
    else:
        # Without pre-processing the bundled default file is returned instead.
        vp_cache_path = 'src/demo/default_vp.txt'

    img = (img / 255.0).astype(np.float32)
    run_one_inference(
        img, model, args, img_name,
        logger=logger,
        show=False,
        show_depth='depth-normal-gradient' in visualization,
        show_floorplan='2d-floorplan' in visualization,
        mesh_format=mesh_format,
        mesh_resolution=int(mesh_resolution),
    )

    def result_path(suffix):
        # Output files are addressed as <output_dir>/<img_name><suffix>.
        return os.path.join(args.output_dir, f"{img_name}{suffix}")

    mesh_path = result_path(f"_3d{mesh_format}")
    return [
        result_path("_pred.png"),
        mesh_path,
        mesh_path,
        vp_cache_path,
        result_path("_pred.json"),
    ]


def get_model(args):
    """Build the model described by ``args.cfg``.

    Loads the config via ``get_config``, makes sure the matching checkpoint is
    present via ``down_ckpt`` and then calls ``build_model``. If either
    ``args.device`` or ``config.TRAIN.DEVICE`` mentions 'cuda' while CUDA is
    not available, both are switched to "cpu" (mutating ``args`` in place).

    :param args: namespace providing at least ``cfg`` and ``device``.
    :return: the first element of the 4-tuple returned by ``build_model``.
    """
    config = get_config(args)
    down_ckpt(args.cfg, config.CKPT.DIR)

    cuda_requested = 'cuda' in args.device or 'cuda' in config.TRAIN.DEVICE
    if cuda_requested and not torch.cuda.is_available():
        logger.info(f'The {args.device} is not available, will use cpu ...')
        # The config has to be unlocked before its device entry can be changed.
        config.defrost()
        args.device = config.TRAIN.DEVICE = "cpu"
        config.freeze()

    # Only the model is needed here; the other three return values are dropped.
    model, _, _, _ = build_model(config, logger)
    return model


if __name__ == '__main__':
    # NOTE: logger, args, mp3d_model and zind_model are module globals on
    # purpose - down_ckpt/get_model/greet read them directly.
    logger = get_logger()
    args = Namespace(device='cuda', output_dir='src/output', visualize_3d=False, output_3d=True)
    os.makedirs(args.output_dir, exist_ok=True)

    # Both models are built up front so that a request only has to pick one.
    args.cfg = 'src/config/mp3d.yaml'
    mp3d_model = get_model(args)

    args.cfg = 'src/config/zind.yaml'
    zind_model = get_model(args)

    description = (
        "This demo of the project "
        "<a href='https://github.com/zhigangjiang/LGT-Net' target='_blank'>LGT-Net</a>. "
        "It uses the Geometry-Aware Transformer Network to predict the 3d room layout of an rgb panorama."
    )

    # Input widgets, in the positional order of greet()'s parameters.
    input_widgets = [
        gr.Image(type='filepath', label='input rgb panorama', value='src/demo/pano_demo1.png'),
        gr.Checkbox(label='pre-processing', value=True),
        gr.Radio(['mp3d', 'zind'], label='pre-trained weight', value='mp3d'),
        gr.Radio(['manhattan', 'atalanta', 'original'], label='post-processing method', value='manhattan'),
        gr.CheckboxGroup(['depth-normal-gradient', '2d-floorplan'],
                         label='2d-visualization',
                         value=['depth-normal-gradient', '2d-floorplan']),
        gr.Radio(['.gltf', '.obj', '.glb'], label='output format of 3d mesh', value='.gltf'),
        gr.Radio(['128', '256', '512', '1024'], label='output resolution of 3d mesh', value='256'),
    ]

    # Output widgets, in the order of the list returned by greet().
    output_widgets = [
        gr.Image(label='predicted result 2d-visualization', type='filepath'),
        gr.Model3D(label='3d mesh reconstruction', clear_color=[1.0, 1.0, 1.0, 1.0]),
        gr.File(label='3d mesh file'),
        gr.File(label='vanishing point information'),
        gr.File(label='layout json'),
    ]

    # (image, pre-processing, weight, post-processing) per example; every
    # example uses both 2d views, the '.gltf' format and resolution '256'.
    example_settings = [
        ('src/demo/pano_demo1.png', True, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo2.png', False, 'mp3d', 'manhattan'),
        ('src/demo/mp3d_demo3.png', False, 'mp3d', 'manhattan'),
        ('src/demo/zind_demo1.png', True, 'zind', 'manhattan'),
        ('src/demo/zind_demo2.png', False, 'zind', 'atalanta'),
        ('src/demo/zind_demo3.png', True, 'zind', 'manhattan'),
        ('src/demo/other_demo1.png', False, 'mp3d', 'manhattan'),
        ('src/demo/other_demo2.png', True, 'mp3d', 'manhattan'),
    ]
    example_rows = [
        [image, use_pre, weight, post, ['depth-normal-gradient', '2d-floorplan'], '.gltf', '256']
        for image, use_pre, weight, post in example_settings
    ]

    demo = gr.Interface(
        fn=greet,
        inputs=input_widgets,
        outputs=output_widgets,
        examples=example_rows,
        title='LGT-Net',
        allow_flagging="never",
        cache_examples=False,
        description=description,
    )

    demo.launch(debug=True, enable_queue=False)