# ---compulsory---
# The competition environment must be initialized with hoho.setup() before any
# data access; setup_environment() below installs the tools package and makes
# that call.

import subprocess
import sys
import os


def setup_environment():
    """Install the competition tools and runtime dependencies, then initialize hoho."""
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "git+http://hf.co/usm3d/tools.git"])
    import hoho
    hoho.setup()

    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "torch==2.0.1", "torchvision==0.15.2", "torchaudio==2.0.2",
                           "-f", "https://download.pytorch.org/whl/cu117.html"])
    # The remaining dependencies are independent of each other, so one pip call suffices.
    subprocess.check_call([sys.executable, "-m", "pip", "install",
                           "scikit-learn", "tqdm", "scipy", "open3d", "easydict"])
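
# In a fresh session (e.g., a new notebook kernel), run this once *before* the
# imports below are executed, since the hoho imports run at import time:
#
#   setup_environment()
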
import webdataset as wds
from tqdm import tqdm
from typing import Dict
import pandas as pd

import time
import io
from PIL import Image as PImage
import numpy as np

# These imports assume hoho is already installed (see setup_environment above);
# `import hoho` is needed at module level because main() calls hoho.get_params()
# and hoho.get_dataset().
import hoho
from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
from hoho import proc, Sample

import torch
import torch.nn as nn
import argparse
import datetime
import glob
import torch.distributed as dist
from torch import optim

from dataset.data_utils import build_dataloader
from test_util import test_model
from model.roofnet import RoofNet
from model import model_utils
from utils import common_utils


def remove_z_outliers(pcd_data, low_threshold_percentage=50, high_threshold_percentage=0):
    """
    Remove outliers from point cloud data based on z-value.

    Parameters:
    - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
    - low_threshold_percentage (float): Percentage of points to remove from the lowest z-values.
    - high_threshold_percentage (float): Percentage of points to remove from the highest z-values.

    Returns:
    - np.array: Filtered point cloud data as an Nx3 numpy array.
    """
    low_z_threshold = np.percentile(pcd_data[:, 2], low_threshold_percentage)
    high_z_threshold = np.percentile(pcd_data[:, 2], 100 - high_threshold_percentage)
    # Trim both tails, as the docstring promises; with the default
    # high_threshold_percentage=0 the upper cut-off keeps every point.
    mask = (pcd_data[:, 2] > low_z_threshold) & (pcd_data[:, 2] <= high_z_threshold)
    return pcd_data[mask]
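
# Illustrative sanity check (synthetic data, nothing beyond numpy assumed):
# trimming the lowest 30% and highest 1% of z-values from 1000 uniformly
# random points should keep roughly 69% of them.
#
#   pts = np.random.rand(1000, 3)
#   kept = remove_z_outliers(pts, low_threshold_percentage=30, high_threshold_percentage=1.0)
#   assert 0.6 < len(kept) / len(pts) < 0.75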
|
|
|
def convert_entry_to_human_readable(entry):
    """Decode a raw webdataset entry: COLMAP binaries become Python objects, image bytes become PIL images."""
    out = {}
    already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics',
                    'mesh_vertices', 'mesh_faces', 'face_semantics', 'K', 'R', 't']
    for k, v in entry.items():
        if k in already_good:
            out[k] = v
        elif k == 'points3d':
            out[k] = read_points3D_binary(fid=io.BytesIO(v))
        elif k == 'cameras':
            out[k] = read_cameras_binary(fid=io.BytesIO(v))
        elif k == 'images':
            out[k] = read_images_binary(fid=io.BytesIO(v))
        elif k in ['ade20k', 'gestalt']:
            out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
        elif k == 'depthcm':
            out[k] = [PImage.open(io.BytesIO(x)) for x in v]
    return out
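
# Hypothetical usage sketch: decode one raw entry and inspect it. The keys
# shown ('__key__', 'points3d') follow the field names handled above.
#
#   raw = next(iter(dataset))
#   human = convert_entry_to_human_readable(raw)
#   print(human['__key__'], len(human['points3d']))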
|
|
|
def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='Data/hoho_data_train', help='dataset path')
    parser.add_argument('--cfg_file', type=str, default='./model_cfg.yaml', help='model config for training')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size for training')
    parser.add_argument('--gpu', type=str, default='0', help='gpu for training')
    parser.add_argument('--test_tag', type=str, default='hoho_train', help='extra tag for this experiment')

    args = parser.parse_args()
    cfg = common_utils.cfg_from_yaml_file(args.cfg_file)
    return args, cfg
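
# Typical invocation (the script name and paths are illustrative):
#
#   python test_hoho.py --data_path Data/hoho_data_train \
#       --cfg_file ./model_cfg.yaml --gpu 0 --test_tag hoho_train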
|
|
|
def save_submission(submission, path):
    """
    Saves the submission to the specified path.

    Parameters:
    - submission (List[Dict]): The solution rows to save.
    - path (str): The path to save the submission to.
    """
    sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
    sub['wf_edges'] = sub['wf_edges'].apply(lambda x: x.tolist())
    sub.to_parquet(path)
    print(f"Submission saved to {path}")
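
# Minimal example row (hypothetical values) matching the schema above; note
# that wf_edges must support .tolist(), e.g. a numpy array of index pairs.
#
#   save_submission([{'__key__': 'building_0',
#                     'wf_vertices': [[0.0, 0.0, 0.0], [1.0, 0.0, 2.5]],
#                     'wf_edges': np.array([[0, 1]])}],
#                   'submission.parquet')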
|
|
|
|
|
def main():
    args, cfg = parse_config()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    extra_tag = args.test_tag
    output_dir = cfg.ROOT_DIR / 'output' / extra_tag
    assert output_dir.exists(), '%s does not exist!!!' % str(output_dir)
    ckpt_dir = output_dir
    output_dir = output_dir / 'test'
    output_dir.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / 'log.txt'
    logger = common_utils.create_logger(log_file)

    logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    common_utils.log_config_to_file(cfg, logger=logger)

    print("------------ Loading dataset ------------")
    params = hoho.get_params()  # kept from the original flow; not used below
    dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')

    # Build the network and load the most recent checkpoint once, outside the
    # per-building loop, since the weights do not change between buildings.
    net = RoofNet(cfg.MODEL)
    net.cuda()
    net.eval()

    ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
    if len(ckpt_list) > 0:
        ckpt_list.sort(key=os.path.getmtime)
        model_utils.load_params(net, ckpt_list[-1], logger=logger)

    logger.info('**********************Start testing**********************')
    logger.info(net)

    solution = []
    for entry in tqdm(dataset, desc="Processing entries"):
        human_entry = convert_entry_to_human_readable(entry)

        key = human_entry['__key__']
        points3D = human_entry['points3d']
        xyz_ = np.stack([p.xyz for p in points3D.values()])
        xyz = remove_z_outliers(xyz_, low_threshold_percentage=30, high_threshold_percentage=1.0)

        # One dataloader per building; test_model is assumed to consume the
        # whole loader and return the key plus the predicted wireframe.
        test_loader = build_dataloader(key, xyz, args.batch_size, cfg.DATA, logger=logger)
        key, pred_vertices, pred_edges = test_model(net, test_loader, logger)
        solution.append({
            '__key__': key,
            'wf_vertices': pred_vertices.tolist(),
            'wf_edges': pred_edges
        })
        print(f"predicted solution: {key}")

    print("saving submission")
    save_submission(solution, "submission.parquet")


if __name__ == '__main__':
    main()