import math
import os
from typing import Union, List, Tuple
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL.Image import Image
from scipy.spatial.distance import jensenshannon
from torch import Tensor
from torch.nn.functional import interpolate
from auxiliary.settings import DEVICE


def print_metrics(current_metrics: dict, best_metrics: dict):
    """ Prints the current error metrics alongside the best values recorded so far """
    print(" Mean ......... : {:.4f} (Best: {:.4f})".format(current_metrics["mean"], best_metrics["mean"]))
    print(" Median ....... : {:.4f} (Best: {:.4f})".format(current_metrics["median"], best_metrics["median"]))
    print(" Trimean ...... : {:.4f} (Best: {:.4f})".format(current_metrics["trimean"], best_metrics["trimean"]))
    print(" Best 25% ..... : {:.4f} (Best: {:.4f})".format(current_metrics["bst25"], best_metrics["bst25"]))
    print(" Worst 25% .... : {:.4f} (Best: {:.4f})".format(current_metrics["wst25"], best_metrics["wst25"]))
    print(" Worst 5% ..... : {:.4f} (Best: {:.4f})".format(current_metrics["wst5"], best_metrics["wst5"]))


def correct(img: Image, illuminant: Tensor) -> Image:
    """
    Corrects the color cast of a linear image based on an estimated (linear) illuminant
    @param img: a linear image
    @param illuminant: a linear illuminant
    @return: a non-linear, color-corrected version of the input image
    """
    img = F.to_tensor(img).to(DEVICE)

    # Correct the image: scaling the unit-norm illuminant by sqrt(3) maps a neutral
    # illuminant (1/sqrt(3), 1/sqrt(3), 1/sqrt(3)) to a per-channel factor of 1
    correction = illuminant.unsqueeze(2).unsqueeze(3) * torch.sqrt(Tensor([3])).to(DEVICE)
    corrected_img = torch.div(img, correction + 1e-10)

    # Normalize the image by its maximum value
    max_img = torch.max(torch.max(torch.max(corrected_img, dim=1)[0], dim=1)[0], dim=1)[0] + 1e-10
    max_img = max_img.unsqueeze(1).unsqueeze(1).unsqueeze(1)
    normalized_img = torch.div(corrected_img, max_img)

    return F.to_pil_image(linear_to_nonlinear(normalized_img).squeeze(), mode="RGB")
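
# A minimal usage sketch (not part of the original module): the image path and the illuminant
# values below are hypothetical; the estimate is assumed to be a unit-norm RGB vector with a
# leading batch dimension, as required by the unsqueeze calls in correct()
#
#   from PIL import Image as PILImage
#   linear_img = PILImage.open("path/to/linear_image.png")
#   estimated_ill = torch.tensor([[0.4472, 0.7746, 0.4472]])
#   corrected_img = correct(linear_img, estimated_ill)
#   corrected_img.save("corrected.png")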


def linear_to_nonlinear(img: Union[np.ndarray, Image, Tensor]) -> Union[np.ndarray, Image, Tensor]:
    """ Applies gamma correction (gamma = 1/2.2) to convert a linear image to non-linear (sRGB-like) space """
    if isinstance(img, np.ndarray):
        return np.power(img, (1.0 / 2.2))
    if isinstance(img, Tensor):
        return torch.pow(img, 1.0 / 2.2)
    return F.to_pil_image(torch.pow(F.to_tensor(img), 1.0 / 2.2).squeeze(), mode="RGB")
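
# Worked example (illustrative, not part of the original module): a linear intensity of 0.5
# maps to 0.5 ** (1 / 2.2), i.e. roughly 0.73, so mid-range values are brightened
#
#   linear_to_nonlinear(np.array([0.0, 0.5, 1.0]))  # -> approximately [0.0, 0.73, 1.0]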


def normalize(img: np.ndarray) -> np.ndarray:
    """ Normalizes a 16-bit image (values in [0, 65535]) to the [0, 1] range """
    max_int = 65535.0
    return np.clip(img, 0.0, max_int) * (1.0 / max_int)


def rgb_to_bgr(x: np.ndarray) -> np.ndarray:
    """ Reverses the values along the leading axis (e.g., an RGB triplet becomes BGR) """
    return x[::-1]


def bgr_to_rgb(x: np.ndarray) -> np.ndarray:
    """ Reverses the channel axis of an image in height x width x channels format """
    return x[:, :, ::-1]


def hwc_to_chw(x: np.ndarray) -> np.ndarray:
    """ Converts an image from height x width x channels to channels x height x width """
    return x.transpose(2, 0, 1)


def scale(x: Tensor) -> Tensor:
    """ Scales all values of a tensor to the [0, 1] range """
    x = x - x.min()
    x = x / x.max()
    return x


def rescale(x: Tensor, size: Tuple) -> Tensor:
    """ Rescales a tensor to the given spatial size (e.g., the input image size) for better visualization """
    return interpolate(x, size, mode='bilinear')


def angular_error(x: Tensor, y: Tensor, safe_v: float = 0.999999) -> float:
    """ Computes the mean angular error (in degrees) between pairs of RGB vectors stacked along dim 0 """
    x, y = torch.nn.functional.normalize(x, dim=1), torch.nn.functional.normalize(y, dim=1)
    dot = torch.clamp(torch.sum(x * y, dim=1), -safe_v, safe_v)
    angle = torch.acos(dot) * (180 / math.pi)
    return torch.mean(angle).item()
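
# Sanity check (illustrative, not part of the original module): orthogonal unit vectors are
# 90 degrees apart; inputs are expected in (batch, 3) shape since normalization runs on dim=1
#
#   angular_error(torch.tensor([[1.0, 0.0, 0.0]]), torch.tensor([[0.0, 1.0, 0.0]]))  # -> 90.0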


def tvd(pred: Tensor, label: Tensor) -> Tensor:
    """
    Total Variation Distance (TVD) is a distance measure for probability distributions,
    computed as half the sum of the absolute differences: 0.5 * sum(|pred - label|)
    https://en.wikipedia.org/wiki/Total_variation_distance_of_probability_measures
    """
    # A plain Python scalar keeps the computation on the same device as the inputs
    return (0.5 * torch.abs(pred - label)).sum()
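
# Worked example (illustrative, not part of the original module): for p = [0.5, 0.5] and
# q = [1.0, 0.0], the absolute differences sum to 1.0, so the TVD is 0.5
#
#   tvd(torch.tensor([0.5, 0.5]), torch.tensor([1.0, 0.0]))  # -> tensor(0.5)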


def jsd(p: List, q: List) -> float:
    """
    Jensen-Shannon Divergence (JSD) between two probability distributions, computed as the square
    of SciPy's Jensen-Shannon distance. Refs:
    - https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.jensenshannon.html
    - https://stackoverflow.com/questions/15880133/jensen-shannon-divergence
    """
    return jensenshannon(p, q) ** 2
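

# A minimal, illustrative smoke test (an addition, not part of the original module); it assumes the
# script is run from the repository root so that the auxiliary.settings import at the top resolves
if __name__ == "__main__":
    # Identical distributions have zero divergence; disjoint ones give ln(2) with SciPy's natural-log base
    print(jsd([0.5, 0.5], [0.5, 0.5]))  # 0.0
    print(jsd([1.0, 0.0], [0.0, 1.0]))  # ~0.6931

    # A 16-bit white pixel normalizes to 1.0
    print(normalize(np.array([65535.0])))  # [1.]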