|
import os |
|
import sys |
|
import time |
|
from datetime import datetime |
|
import logging |
|
import numpy as np |
|
import torch |
|
import math |
|
|
|
def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    now = datetime.now()
    # datetime.__format__ delegates a non-empty spec to strftime.
    return format(now, '%y%m%d-%H%M%S')
|
|
|
def mkdir_and_rename(path):
    """Create directory ``path``, archiving an existing one first.

    If ``path`` already exists, it is renamed to
    ``<path>_archived_<timestamp>`` (timestamp from ``get_timestamp()``).
    A fresh directory is then created at ``path`` with ``os.makedirs``.

    Args:
        path (str | os.PathLike): Directory to create.
    """
    if os.path.exists(path):
        # Strip trailing separators before appending the suffix. For 'exp/'
        # the naive concatenation yields 'exp/_archived_...', a target inside
        # the directory being renamed, which os.rename rejects.
        seps = os.sep + (os.altsep or '')
        base = os.fspath(path).rstrip(seps) or os.fspath(path)
        new_name = base + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(new_name))
        os.rename(path, new_name)
    os.makedirs(path)
|
|
|
|
|
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with ``.`` (hidden files) are skipped.
    Sub-directories, including hidden ones, are descended into only when
    ``recursive`` is True. Other non-file entries are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None (all files).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator over the interested files. Paths are relative to
        ``dir_path``, or full paths when ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. This is
            raised immediately, not on first iteration.
    """

    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # The context manager releases the OS directory handle even if the
        # consumer stops iterating early.
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        continue  # skip hidden files

                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = os.path.relpath(entry.path, root)

                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only real directories are descended into. Previously
                    # hidden files and broken symlinks reached os.scandir
                    # here and raised NotADirectoryError / FileNotFoundError.
                    yield from _scandir(
                        entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
|
|
|
|
def setup_logger(log_file_path):
    """Configure the root logger to write to ``log_file_path`` and stdout.

    Sets the root logger level to INFO and attaches a UTF-8
    ``logging.FileHandler`` and a ``logging.StreamHandler`` on ``sys.stdout``
    that share one formatter. A handler for the same file, or one already
    writing to ``sys.stdout``, is not attached twice. Repeated calls
    therefore do not duplicate log lines.

    Args:
        log_file_path (str | os.PathLike): Path of the log file. It is opened
            in append mode, the ``FileHandler`` default.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-5.5s] %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # FileHandler stores the absolute path of its file in ``baseFilename``.
    abs_log_path = os.path.abspath(log_file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == abs_log_path
        for h in root_logger.handlers)
    if not has_file_handler:
        log_file_handler = logging.FileHandler(log_file_path, encoding='utf-8')
        log_file_handler.setFormatter(log_formatter)
        root_logger.addHandler(log_file_handler)

    has_stdout_handler = any(
        isinstance(h, logging.StreamHandler) and h.stream is sys.stdout
        for h in root_logger.handlers)
    if not has_stdout_handler:
        log_stream_handler = logging.StreamHandler(sys.stdout)
        log_stream_handler.setFormatter(log_formatter)
        root_logger.addHandler(log_stream_handler)

    logging.info('Logging file is %s', log_file_path)
|
|
|
|
|
def print_args(args):
    """Log each attribute of ``args`` as ``name:value`` at INFO level.

    Args:
        args: Any object accepted by ``vars()``, presumably an
            ``argparse.Namespace`` (verify against callers).
    """
    for arg in vars(args):
        # Pass the value as a logging argument rather than pre-formatting it
        # with ':%s' % value. That form raises TypeError for tuples (and
        # silently unpacks 1-tuples), because '(value)' is not a tuple.
        logging.info('%s:%s', arg, getattr(args, arg))