### This is an example of the script that will be run in the test environment.
### Some parts of the code are compulsory and you should NOT CHANGE THEM.
### They are between '''---compulsory---''' comments.
### You can change the rest of the code to define and test your solution.
### However, you should not change the signature of the provided function.
### The script will save a "submission.parquet" file in the current directory.
### The actual logic of the solution is implemented in the `handcrafted_solution.py` file.
### The `handcrafted_solution.py` file is a placeholder for your solution:
### you should implement the logic of your solution in that file.
### You can use any additional files and subdirectories to organize your code.
'''---compulsory---'''
# import subprocess
# from pathlib import Path
#
# def install_package_from_local_file(package_name, folder='packages'):
#     """
#     Installs a package from a local .whl file or a directory containing .whl files using pip.
#
#     Parameters:
#         package_name (str): The name of the package to install.
#         folder (str): The directory containing the .whl file(s) for the package.
#     """
#     try:
#         pth = str(Path(folder) / package_name)
#         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
#                                "--no-index",         # Do not use the package index
#                                "--find-links", pth,  # Look for packages in the specified directory or file
#                                package_name])        # Specify the package to install
#         print(f"Package installed successfully from {pth}")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install package from {pth}. Error: {e}")
#
# install_package_from_local_file('hoho')
import hoho; hoho.setup() # YOU MUST CALL hoho.setup() BEFORE ANYTHING ELSE
# import subprocess
# import importlib
# from pathlib import Path
#
# ### The function below is useful for installing additional Python wheels.
# def install_package_from_local_file(package_name, folder='packages'):
#     """
#     Installs a package from a local .whl file or a directory containing .whl files using pip.
#
#     Parameters:
#         package_name (str): The name of the package to install.
#         folder (str): The directory containing the .whl file(s) for the package.
#     """
#     try:
#         pth = str(Path(folder) / package_name)
#         subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
#                                "--no-index",         # Do not use the package index
#                                "--find-links", pth,  # Look for packages in the specified directory or file
#                                package_name])        # Specify the package to install
#         print(f"Package installed successfully from {pth}")
#     except subprocess.CalledProcessError as e:
#         print(f"Failed to install package from {pth}. Error: {e}")
#
# # Wheels can be pre-downloaded into the packages/ directory, e.g.:
# #   pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
# install_package_from_local_file('webdataset')
# install_package_from_local_file('tqdm')
### Here you can import any library or module you want.
### The code below is used to read and parse the input dataset.
### Please, do not modify it.
import subprocess
import sys
import os
# Setup environment and install necessary packages
def setup_environment():
    subprocess.check_call([sys.executable, "-m", "pip", "install", "git+http://hf.co/usm3d/tools.git"])
    import hoho
    hoho.setup()
    subprocess.check_call([sys.executable, "-m", "pip", "install", "torch==2.0.1", "torchvision==0.15.2", "torchaudio==2.0.2", "-f", "https://download.pytorch.org/whl/cu117.html"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scikit-learn"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "tqdm"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "scipy"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "open3d"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "easydict"])
    # pc_util_path = os.path.join(os.getcwd(), 'pc_util')
    # if os.path.isdir(pc_util_path):
    #     os.chdir(pc_util_path)
    #     subprocess.check_call([sys.executable, "setup.py", "install"])
    # else:
    #     print(f"Directory {pc_util_path} does not exist")
import webdataset as wds
from tqdm import tqdm
from typing import Dict
import pandas as pd
# from transformers import AutoTokenizer
import os
import time
import io
from PIL import Image as PImage
import numpy as np
from hoho.read_write_colmap import read_cameras_binary, read_images_binary, read_points3D_binary
from hoho import proc, Sample
### Imports for our solution
import torch
import torch.nn as nn
import argparse
import datetime
import glob
import torch.distributed as dist
from torch import optim
from dataset.data_utils import build_dataloader
from test_util import test_model
from model.roofnet import RoofNet
from model import model_utils
from utils import common_utils
def remove_z_outliers(pcd_data, low_threshold_percentage=50, high_threshold_percentage=0):
    """
    Remove outliers from point cloud data based on z-value.

    Parameters:
    - pcd_data (np.array): Nx3 numpy array containing the point cloud data.
    - low_threshold_percentage (float): Percentage of points to be removed based on the lowest z-values.
    - high_threshold_percentage (float): Percentage of points to be removed based on the highest z-values.

    Returns:
    - np.array: Filtered point cloud data as an Nx3 numpy array.
    """
    num_std = 3
    low_z_threshold = np.percentile(pcd_data[:, 2], low_threshold_percentage)
    high_z_threshold = np.percentile(pcd_data[:, 2], 100 - high_threshold_percentage)
    mean_z, std_z = np.mean(pcd_data[:, 2]), np.std(pcd_data[:, 2])
    z_range = (mean_z - num_std * std_z, mean_z + num_std * std_z)
    # Note: only the low-z percentile filter is applied below; high_z_threshold and z_range
    # are computed but currently unused (the stricter combined filter is commented out).
    # filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold) & (pcd_data[:, 2] < z_range[1])]
    filtered_pcd_data = pcd_data[(pcd_data[:, 2] > low_z_threshold)]
    return filtered_pcd_data
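# Quick sanity check on hypothetical data (not from the dataset): with
# low_threshold_percentage=50, points whose z lies at or below the 50th percentile are dropped.
# >>> pts = np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.], [0., 0., 3.]])
# >>> remove_z_outliers(pts, low_threshold_percentage=50)[:, 2]
# array([2., 3.])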
def convert_entry_to_human_readable(entry):
    out = {}
    already_good = ['__key__', 'wf_vertices', 'wf_edges', 'edge_semantics', 'mesh_vertices',
                    'mesh_faces', 'face_semantics', 'K', 'R', 't']
    for k, v in entry.items():
        if k in already_good:
            out[k] = v
            continue
        if k == 'points3d':
            out[k] = read_points3D_binary(fid=io.BytesIO(v))
        if k == 'cameras':
            out[k] = read_cameras_binary(fid=io.BytesIO(v))
        if k == 'images':
            out[k] = read_images_binary(fid=io.BytesIO(v))
        if k in ['ade20k', 'gestalt']:
            out[k] = [PImage.open(io.BytesIO(x)).convert('RGB') for x in v]
        if k == 'depthcm':
            out[k] = [PImage.open(io.BytesIO(x)) for x in entry['depthcm']]
    return out
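# Illustrative usage (a sketch; field names follow hoho's COLMAP readers, as used in main() below):
# human_entry = convert_entry_to_human_readable(entry)
# pts3d = human_entry['points3d']                    # dict mapping point3D id -> COLMAP Point3D record
# xyz = np.stack([p.xyz for p in pts3d.values()])    # (N, 3) array of sparse point coordinates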
def parse_config():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, default='Data/hoho_data_train', help='dataset path')
    parser.add_argument('--cfg_file', type=str, default='./model_cfg.yaml', help='model config for training')
    parser.add_argument('--batch_size', type=int, default=1, help='batch size for training')
    parser.add_argument('--gpu', type=str, default='0', help='gpu for training')
    parser.add_argument('--test_tag', type=str, default='hoho_train', help='extra tag for this experiment')
    args = parser.parse_args()
    cfg = common_utils.cfg_from_yaml_file(args.cfg_file)
    return args, cfg
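# Example invocation (the values shown are just the defaults declared above):
#   python script.py --data_path Data/hoho_data_train --cfg_file ./model_cfg.yaml --batch_size 1 --gpu 0 --test_tag hoho_train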
def save_submission(submission, path):
    """
    Saves the submission to a specified path.

    Parameters:
        submission (List[Dict]): The submission to save.
        path (str): The path to save the submission to.
    """
    sub = pd.DataFrame(submission, columns=["__key__", "wf_vertices", "wf_edges"])
    sub['wf_edges'] = sub['wf_edges'].apply(lambda x: x.tolist())  # Convert to list of lists
    sub.to_parquet(path)
    print(f"Submission saved to {path}")
def main():
    # Setup packages (hoho is already set up at the top of this script).
    # setup_environment()
    args, cfg = parse_config()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    extra_tag = args.test_tag
    output_dir = cfg.ROOT_DIR / 'output' / extra_tag
    assert output_dir.exists(), '%s does not exist!!!' % str(output_dir)
    ckpt_dir = output_dir  # / 'ckpt'
    output_dir = output_dir / 'test'
    output_dir.mkdir(parents=True, exist_ok=True)

    log_file = output_dir / 'log.txt'
    logger = common_utils.create_logger(log_file)
    logger.info('**********************Start logging**********************')
    for key, val in vars(args).items():
        logger.info('{:16} {}'.format(key, val))
    common_utils.log_config_to_file(cfg, logger=logger)

    print("------------ Loading dataset ------------")
    params = hoho.get_params()
    dataset = hoho.get_dataset(decode=None, split='all', dataset_type='webdataset')
    # dataset = dataset.decode()
    # dataset = dataset.map(proc)

    solution = []  # accumulates one prediction record per building across all entries
    for entry in tqdm(dataset, desc="Processing entries"):
        human_entry = convert_entry_to_human_readable(entry)
        # human_entry = entry
        key = human_entry['__key__']
        points3D = human_entry['points3d']
        xyz_ = np.stack([p.xyz for p in points3D.values()])
        xyz = remove_z_outliers(xyz_, low_threshold_percentage=30, high_threshold_percentage=1.0)

        # TODO: convert the webdataset entry into our dataloader format; see roofn3d_dataset.py L152
        test_loader = build_dataloader(key, xyz, args.batch_size, cfg.DATA, logger=logger)

        net = RoofNet(cfg.MODEL)
        net.cuda()
        net.eval()
        ckpt_list = glob.glob(str(ckpt_dir / '*checkpoint_epoch_*.pth'))
        if len(ckpt_list) > 0:
            ckpt_list.sort(key=os.path.getmtime)
            model_utils.load_params(net, ckpt_list[-1], logger=logger)
        logger.info('**********************Start testing**********************')
        logger.info(net)

        for sample in tqdm(test_loader):
            key, pred_vertices, pred_edges = test_model(net, test_loader, logger)
            solution.append({
                '__key__': key,
                'wf_vertices': pred_vertices.tolist(),
                'wf_edges': pred_edges
            })
            print(f"predict solution: {key}")

    # save_submission(solution, output_dir / "submission.parquet")
    print("saving submission")
    save_submission(solution, "submission.parquet")
    # test_model(net, test_loader, logger)
if __name__ == '__main__':
    main()