Spaces:
Running
Running
import torch | |
def loss_generator(ignore: list = None):
    """Return the registry that maps loss names to loss callables.

    Args:
        ignore: Optional collection of registry keys (for example
            ``['mse', 'lut_mse']``) to leave out of the returned mapping.
            ``None`` or an empty collection keeps every entry. The argument
            itself is never modified.

    Returns:
        dict: A freshly built ``{name: function}`` mapping without the
        ignored names.

    Raises:
        KeyError: If ``ignore`` contains a name that is not a registered loss.
    """
    loss_fn = {'mse': mse,
               'lut_mse': lut_mse,
               'masked_mse': masked_mse,
               'sample_weighted_mse': sample_weighted_mse,
               'regularize_LUT': regularize_LUT,
               'MaskWeightedMSE': MaskWeightedMSE}
    if ignore:
        ignored = set(ignore)
        # Fail loudly on typos instead of silently returning a loss the
        # caller believed was disabled.
        unknown = sorted(name for name in ignored if name not in loss_fn)
        if unknown:
            raise KeyError(f"unknown loss name(s) in ignore: {unknown}")
        # Entries are dropped from the registry; the caller's `ignore`
        # collection is left untouched.
        loss_fn = {name: fn for name, fn in loss_fn.items()
                   if name not in ignored}
    return loss_fn
def mse(pred, gt):
    """Return the mean of the element-wise squared difference ``pred - gt``.

    The mean is taken over every element, so the result is a 0-dim tensor.
    """
    residual = pred - gt
    return residual.pow(2).mean()
def masked_mse(pred, gt, mask):
    """Squared error restricted to the mask, normalized per sample.

    For every sample (index along dim 0) the squared error is summed over the
    positions where ``mask > 100 / 255`` and divided by the sum of the raw
    mask values, clamped to at least 100. The per-sample values are then
    averaged.

    Args:
        pred: Predicted tensor.
        gt: Target tensor, subtracted element-wise from ``pred``.
        mask: Mask tensor; its rank decides which dims are reduced (all but
            dim 0). Presumably a soft mask in ``[0, 1]`` with the same shape
            as ``pred`` -- TODO confirm against callers.

    Returns:
        0-dim tensor on ``pred``'s device.
    """
    reduce_dims = list(range(1, mask.dim()))
    # Keep the normalizer on the same device as the predictions rather than
    # forcing CUDA, so the loss also runs on CPU and on non-default GPUs.
    area = torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100).to(pred.device)
    # NOTE(review): the numerator uses a thresholded (binary) mask while the
    # normalizer sums the raw mask values -- confirm this is intended.
    selected = mask > 100 / 255.
    per_sample = torch.sum(selected * (pred - gt) ** 2, dim=reduce_dims) / area
    return torch.mean(per_sample)
def sample_weighted_mse(pred, gt, mask):
    """Per-sample MSE, weighted by each sample's relative mask sum.

    Each sample's weight is its mask sum (clamped to at least 100) divided by
    the total of those clamped sums over the batch, so the weights add up to
    one. The result is the weighted sum of the per-sample mean squared errors.

    Args:
        pred: Predicted tensor.
        gt: Target tensor, subtracted element-wise from ``pred``.
        mask: Mask tensor; its rank decides which dims are reduced (all but
            dim 0). Assumes ``pred``/``gt`` have the same rank as ``mask``
            -- TODO confirm against callers.

    Returns:
        0-dim tensor on the device of ``pred``/``gt``.
    """
    reduce_dims = list(range(1, mask.dim()))
    per_sample = torch.mean((pred - gt) ** 2, dim=reduce_dims)
    # Move the weights to wherever the error lives instead of forcing CUDA:
    # this keeps a CPU mask usable with GPU predictions and also lets the
    # loss run on CPU-only hosts.
    weights = torch.clamp_min(torch.sum(mask, dim=reduce_dims), 100).to(per_sample.device)
    weights = weights / weights.sum()
    return torch.sum(per_sample * weights)
def regularize_LUT(lut):
    """Penalize LUT entries that fall outside the closed interval [0, 1].

    Entries below 0 contribute their mean squared value; entries above 1
    contribute the mean squared excess over 1. A side with no offending
    entries contributes the plain integer ``0``, so the result is the int
    ``0`` when every entry is already inside the interval.
    """
    def _mean_square(values):
        # An empty selection would yield NaN from .mean(); use 0 instead.
        if min(values.shape) == 0:
            return 0
        return (values ** 2).mean()

    underflow = lut[lut < 0.]
    overflow = lut[lut > 1.]
    low_penalty = _mean_square(underflow)
    high_penalty = _mean_square(overflow - 1.)
    return high_penalty + low_penalty
def lut_mse(feat, lut_batch):
    """Sum the pairwise MSE inside each group of ``lut_batch`` entries.

    ``feat`` is split along dim 0 into consecutive groups of ``lut_batch``
    entries; trailing entries that do not fill a whole group are ignored.
    Within a group every ordered pair, self-pairs included, is passed to
    ``mse`` and the values are accumulated. The grand total is divided by
    ``lut_batch``.
    """
    total = 0
    group_count = feat.shape[0] // lut_batch
    for group_idx in range(group_count):
        start = group_idx * lut_batch
        group = feat[start: start + lut_batch]
        for first in group:
            for second in group:
                total += mse(first, second)
    return total / lut_batch
def MaskWeightedMSE(pred, label, mask):
    """Squared error normalized per sample by the mask sum.

    ``label`` is viewed with ``pred``'s size. The squared error is summed over
    every dim except dim 0 and divided by ``pred.size(1)`` times the mask sum
    over those same dims, clamped to at least 100. The per-sample values are
    then averaged.

    NOTE(review): ``pred.size(1)`` is presumably the channel count -- confirm.
    """
    target = label.view(pred.size())
    dims = get_dims_with_exclusion(target.dim(), 0)
    squared_error = (pred - target) ** 2
    mask_area = torch.clamp_min(torch.sum(mask, dim=dims), 100)
    normalizer = pred.size(1) * mask_area
    per_sample = torch.sum(squared_error, dim=dims) / normalizer
    return torch.mean(per_sample)
def get_dims_with_exclusion(dim, exclude=None):
    """Return ``[0, 1, ..., dim - 1]`` with ``exclude`` removed when given.

    Raises:
        ValueError: If ``exclude`` is not ``None`` and is not one of the
            generated indices.
    """
    remaining = [axis for axis in range(dim)]
    if exclude is None:
        return remaining
    remaining.remove(exclude)
    return remaining