|
import logging |
|
import os |
|
import torch |
|
from torchvision import transforms |
|
import numpy as np |
|
import random |
|
import cv2 |
|
from PIL import Image |
|
|
|
|
|
def path_to_image(path, size=(1024, 1024), color_type='rgb'):
    """Load an image file from disk and return it as a PIL image.

    Args:
        path: Filesystem path of the image, read with OpenCV.
        size: Target size passed to ``cv2.resize`` as ``dsize`` (OpenCV
            interprets it as ``(width, height)``). Pass a falsy value such as
            ``None`` to keep the original resolution.
        color_type: ``'rgb'`` for a 3-channel RGB image or ``'gray'`` for a
            single-channel image (case-insensitive).

    Returns:
        A ``PIL.Image.Image`` in mode ``'RGB'`` or ``'L'``, or ``None`` when
        ``color_type`` is not recognised (a message is printed instead).

    Raises:
        FileNotFoundError: if ``path`` does not point to an existing file.
        ValueError: if OpenCV cannot decode the file as an image.
    """
    mode = color_type.lower()
    if mode not in ('rgb', 'gray'):
        print('Select the color_type to return, either to RGB or gray image.')
        return
    if not os.path.isfile(path):
        raise FileNotFoundError(f'Image file not found: {path}')

    if mode == 'rgb':
        image = cv2.imread(path)
    else:
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the problem only surfaces later as an opaque cv2.error.
    if image is None:
        raise ValueError(f'OpenCV could not decode image file: {path}')

    if size:
        image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)

    if mode == 'rgb':
        # OpenCV loads color images in BGR channel order; PIL expects RGB.
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).convert('RGB')
    return Image.fromarray(image).convert('L')
|
|
|
|
|
|
|
def check_state_dict(state_dict, unwanted_prefix='_orig_mod.'):
    """Strip ``unwanted_prefix`` from matching keys of ``state_dict`` in place.

    Renamed entries are popped and re-inserted, so they move to the end of the
    dict's iteration order. If a stripped key already exists, its value is
    overwritten. The default prefix is presumably the one added by a
    ``torch.compile`` wrapper; confirm against the checkpoints being loaded.

    Returns:
        The same (mutated) ``state_dict`` object.
    """
    cut = len(unwanted_prefix)
    # Snapshot the keys first: the dict is resized while we rename entries.
    for name in list(state_dict.keys()):
        if not name.startswith(unwanted_prefix):
            continue
        state_dict[name[cut:]] = state_dict.pop(name)
    return state_dict
|
|
|
|
|
def generate_smoothed_gt(gts, epsilon=0.001):
    """Apply symmetric label smoothing to ground-truth values.

    Values are mapped linearly as ``(1 - epsilon) * gts + epsilon / 2``, so a
    target of 0 becomes ``epsilon / 2`` and a target of 1 becomes
    ``1 - epsilon / 2``.

    Args:
        gts: Ground-truth values; any object supporting scalar ``*`` and ``+``
            (float, NumPy array, torch tensor, ...). Presumably masks in
            [0, 1]; confirm with callers.
        epsilon: Smoothing strength. The default ``0.001`` matches the value
            that was previously hard-coded.

    Returns:
        The smoothed values, with whatever type the arithmetic on ``gts`` yields.
    """
    return (1 - epsilon) * gts + epsilon / 2
|
|
|
|
|
class Logger:
    """Wrapper around the shared ``'BiRefNet'`` logger.

    Sends INFO-level messages both to a log file and to a console stream.
    """

    def __init__(self, path="log.txt"):
        """Attach a file handler and a console handler to the shared logger.

        Args:
            path: Log file to write. It is opened in ``"w"`` mode, so any
                existing content is overwritten.
        """
        # getLogger returns a process-wide singleton: the handlers added here
        # stay attached to 'BiRefNet' until close() detaches them.
        self.logger = logging.getLogger('BiRefNet')
        formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
        self.file_handler = logging.FileHandler(path, "w")
        # NOTE: StreamHandler() with no argument writes to sys.stderr,
        # despite the attribute name.
        self.stdout_handler = logging.StreamHandler()
        self.stdout_handler.setFormatter(formatter)
        self.file_handler.setFormatter(formatter)
        self.logger.addHandler(self.file_handler)
        self.logger.addHandler(self.stdout_handler)
        self.logger.setLevel(logging.INFO)
        # Do not forward records to ancestor (e.g. root) handlers as well.
        self.logger.propagate = False

    def info(self, txt):
        """Log ``txt`` at INFO level through every attached handler."""
        self.logger.info(txt)

    def close(self):
        """Detach this instance's handlers from the shared logger and close them.

        Detaching first matters because the logger is process-wide. A closed
        StreamHandler still emits, so leaving it attached would duplicate
        console output once another Logger is created.
        """
        for handler in (self.file_handler, self.stdout_handler):
            self.logger.removeHandler(handler)
            handler.close()
|
|
|
|
|
class AverageMeter:
    """Track the most recent value and a weighted running mean.

    Attributes (all start at ``0.0``):
        val:   value from the latest ``update`` call.
        sum:   accumulated ``val * n`` over all updates.
        count: accumulated weight ``n``.
        avg:   ``sum / count`` after the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0.0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
|
|
|
|
|
def save_checkpoint(state, path, filename="latest.pth"):
    """Serialize ``state`` with ``torch.save`` into ``<path>/<filename>``.

    The directory ``path`` is not created here; it must already exist.
    """
    checkpoint_file = os.path.join(path, filename)
    torch.save(state, checkpoint_file)
|
|
|
|
|
def save_tensor_img(tenor_im, path):
    """Write a tensor image to ``path`` via torchvision's ``ToPILImage``.

    The tensor is copied to the CPU, and dimension 0 is squeezed away when it
    has size 1 before conversion. The input is presumably (1, C, H, W) or
    (C, H, W); confirm with callers. The misspelled parameter name
    ``tenor_im`` is kept for keyword-argument compatibility.
    """
    to_pil = transforms.ToPILImage()
    cpu_copy = tenor_im.cpu().clone()
    pil_image = to_pil(cpu_copy.squeeze(0))
    pil_image.save(path)
|
|
|
|
|
def set_seed(seed):
    """Seed every RNG used in training and request deterministic cuDNN kernels.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and Python's
    ``random`` module with the same ``seed``. Only
    ``torch.backends.cudnn.deterministic`` is set; ``cudnn.benchmark`` is left
    as-is.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
|
|