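# NOTE(review): this file originally began with the stray lines "Spaces:",
# "Build error", "Build error", which are not valid Python; kept here as a comment.
# A reimplemented version in public environments by Xiao Fu and Mu Hu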
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import sys | |
sys.path.append("..") | |
from dataloader.mix_loader import MixDataset | |
from torch.utils.data import DataLoader | |
from dataloader import transforms | |
import os | |
# Dataset preparation: build the training dataset and its DataLoader.
def prepare_dataset(data_dir=None,
                    batch_size=1,
                    test_batch=1,
                    datathread=4,
                    logger=None):
    """Build the shuffled training DataLoader over ``MixDataset``.

    Args:
        data_dir: forwarded unchanged to ``MixDataset(data_dir=...)``.
        batch_size: batch size of the training DataLoader.
        test_batch: accepted for call-site compatibility; not used here.
        datathread: default number of DataLoader worker processes. The
            ``datathread`` environment variable, when set, overrides it.
        logger: optional logger; when given, the worker count is logged.

    Returns:
        ``(train_loader, dataset_config_dict)`` where the dict holds
        ``'num_batches_per_epoch'`` (``len(train_loader)``) and ``'img_size'``
        (``(img_height, img_width)`` as reported by the dataset's
        ``get_img_size()``).

    Raises:
        ValueError: if the ``datathread`` environment variable is set but is
            not an integer.
    """
    train_dataset = MixDataset(data_dir=data_dir)
    img_height, img_width = train_dataset.get_img_size()

    # The environment variable, when present, takes precedence over the argument.
    env_threads = os.environ.get('datathread')
    if env_threads is not None:
        try:
            datathread = int(env_threads)
        except ValueError as err:
            raise ValueError(
                "environment variable 'datathread' must be an integer, "
                "got %r" % env_threads) from err
    if logger is not None:
        logger.info("Use %d processes to load data..." % datathread)

    train_loader = DataLoader(train_dataset, batch_size=batch_size,
                              shuffle=True, num_workers=datathread,
                              pin_memory=True)

    dataset_config_dict = {
        'num_batches_per_epoch': len(train_loader),
        'img_size': (img_height, img_width),
    }
    return train_loader, dataset_config_dict
def depth_scale_shift_normalization(depth):
    """Map each sample's values to [-1, 1] using robust percentile bounds.

    For every batch element, the 2nd and 98th percentiles of channel 0 are
    taken as the low/high bounds. They are applied to all channels of that
    element: the low bound maps to -1, the high bound to +1, and the result
    is clipped to [-1, 1].

    Args:
        depth: 4-D tensor; the indexing ``depth[:, 0, :, :]`` requires rank 4
            and treats dim 0 as batch and dim 1 as channel.

    Returns:
        Tensor with the same shape, dtype and device as ``depth``.
    """
    bsz = depth.shape[0]
    # Percentiles are computed with NumPy on the host. Detach first so this
    # also works on tensors that require grad (``Tensor.numpy()`` refuses
    # those). Lift half-precision inputs to float32, because NumPy has no
    # bfloat16.
    flat = depth[:, 0, :, :].detach().reshape(bsz, -1)
    if flat.dtype in (torch.float16, torch.bfloat16):
        flat = flat.float()
    # One pass for both bounds gives an array of shape (2, bsz).
    bounds = np.percentile(a=flat.cpu().numpy(), q=[2, 98], axis=1)
    # Shape (2, bsz, 1, 1, 1), so each bound broadcasts against (bsz, C, H, W).
    bounds = torch.from_numpy(bounds).to(depth)[..., None, None, None]
    min_value, max_value = bounds[0], bounds[1]
    # The 1e-5 guards against division by zero for constant inputs.
    normalized_depth = ((depth - min_value) / (max_value - min_value + 1e-5) - 0.5) * 2
    return torch.clip(normalized_depth, -1., 1.)
def resize_max_res_tensor(input_tensor, mode, recom_resolution=768):
    """Rescale a 3-channel image tensor so its longer side is ``recom_resolution``.

    A single factor ``min(R / H, R / W)`` is applied to both spatial dims, so
    the image may be shrunk or enlarged; the aspect ratio is preserved.

    Args:
        input_tensor: 4-D tensor whose dim 1 must have size 3 (asserted); its
            last two dims are read as (H, W).
        mode: ``'normal'`` uses nearest-neighbour sampling. Any other value
            uses bilinear (``align_corners=False``). For ``'depth'``, the
            resized values are additionally divided by the scale factor.
        recom_resolution: target length R of the longer side.

    Returns:
        The resized tensor.
    """
    assert input_tensor.shape[1] == 3
    height, width = input_tensor.shape[2:]
    scale = min(recom_resolution / height, recom_resolution / width)

    # Nearest sampling is used for 'normal' so vectors are not blended.
    if mode == 'normal':
        interp_kwargs = {'mode': 'nearest'}
    else:
        interp_kwargs = {'mode': 'bilinear', 'align_corners': False}
    # NOTE(review): F.interpolate floors size * scale_factor, so the long side
    # may come out one pixel short of R for some sizes. Confirm whether callers
    # need an exact size.
    resized = F.interpolate(input_tensor, scale_factor=scale, **interp_kwargs)

    return resized / scale if mode == 'depth' else resized