Spaces:
Running
Running
File size: 2,310 Bytes
033bd8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
import torch
def loss_generator(ignore: list = None):
    """Build the name -> loss-function registry, optionally without some entries.

    Args:
        ignore: Names of loss functions to leave out of the returned dict.
            Names that are not registered are skipped. The list passed in is
            not modified.

    Returns:
        A new dict mapping each loss name to the corresponding function.
    """
    loss_fn = {
        'mse': mse,
        'lut_mse': lut_mse,
        'masked_mse': masked_mse,
        'sample_weighted_mse': sample_weighted_mse,
        'regularize_LUT': regularize_LUT,
        'MaskWeightedMSE': MaskWeightedMSE,
    }
    if ignore:
        # Remove the ignored names from the registry (not from ``ignore``):
        # popping from the list being iterated, with a string index, fails.
        for name in ignore:
            loss_fn.pop(name, None)
    return loss_fn
def mse(pred, gt):
    """Return the mean squared error between ``pred`` and ``gt``.

    The difference is broadcast as usual and averaged over every element,
    producing a 0-d tensor.
    """
    diff = pred - gt
    return (diff ** 2).mean()
def masked_mse(pred, gt, mask, threshold=100 / 255., min_area=100):
    """Masked per-sample squared error, averaged over the batch.

    Elements where ``mask > threshold`` contribute their squared error. Each
    sample's error sum is divided by that sample's raw mask sum, clamped from
    below at ``min_area`` so that tiny masks cannot blow up the loss.

    Args:
        pred: Prediction tensor; dim 0 is treated as the batch dimension.
        gt: Target tensor, broadcastable against ``pred``.
        mask: Mask tensor, broadcastable against ``pred``. Every dim after
            dim 0 (counted on ``mask``'s rank) is reduced. NOTE(review): the
            default threshold of 100/255 suggests mask values in [0, 1] —
            confirm with callers.
        threshold: Mask values strictly above this are treated as "inside".
        min_area: Lower bound applied to each sample's mask sum.

    Returns:
        0-d tensor: the mean over the batch of the normalised masked error.
    """
    reduce_dims = list(range(1, mask.dim()))
    # The numerator uses the binarised mask while the denominator sums the raw
    # mask values; this is the existing formulation and is kept as is.
    per_sample = torch.sum((mask > threshold) * (pred - gt) ** 2, dim=reduce_dims)
    # Follow the numerator's device instead of forcing CUDA, so the loss also
    # works on CPU and never mixes devices in the division.
    denominator = torch.clamp_min(torch.sum(mask, dim=reduce_dims), min_area)
    denominator = denominator.to(per_sample.device)
    return torch.mean(per_sample / denominator)
def sample_weighted_mse(pred, gt, mask, min_area=100):
    """Batch MSE in which each sample is weighted by its clamped mask area.

    Each sample's MSE is taken over every dim after dim 0 (counted on
    ``mask``'s rank). The weights are the per-sample mask sums, clamped from
    below at ``min_area`` and normalised to sum to 1.

    Args:
        pred: Prediction tensor; dim 0 is treated as the batch dimension.
        gt: Target tensor, broadcastable against ``pred``.
        mask: Mask tensor whose per-sample sum sets that sample's weight.
        min_area: Lower bound applied to each sample's mask sum before
            normalising.

    Returns:
        0-d tensor: the weighted sum of the per-sample MSEs.
    """
    reduce_dims = list(range(1, mask.dim()))
    per_sample = torch.mean((pred - gt) ** 2, dim=reduce_dims)
    # Keep the weights on the loss's device instead of forcing CUDA.
    area = torch.clamp_min(torch.sum(mask, dim=reduce_dims), min_area)
    area = area.to(per_sample.device)
    weights = area / area.sum()
    return torch.sum(per_sample * weights)
def regularize_LUT(lut):
st = lut[lut < 0.]
reg_st = (st ** 2).mean() if min(st.shape) != 0 else 0
lt = lut[lut > 1.]
reg_lt = ((lt - 1.) ** 2).mean() if min(lt.shape) != 0 else 0
return reg_lt + reg_st
def lut_mse(feat, lut_batch):
    """Sum of pairwise ``mse`` within consecutive groups of ``lut_batch`` rows.

    ``feat`` is split along dim 0 into ``feat.shape[0] // lut_batch`` full
    groups; trailing rows that do not fill a group are ignored. Within each
    group, ``mse(a, b)`` is accumulated over every ordered pair of rows,
    including ``a`` paired with itself. The total is divided by ``lut_batch``,
    giving ``0.0`` when there is no full group.
    """
    total = 0
    group_count = feat.shape[0] // lut_batch
    for group_idx in range(group_count):
        start = group_idx * lut_batch
        group = feat[start:start + lut_batch]
        for a in group:
            for b in group:
                total += mse(a, b)
    return total / lut_batch
def MaskWeightedMSE(pred, label, mask, min_area=100):
    """Squared error normalised by ``pred.size(1)`` and per-sample mask area.

    For each sample (dim 0), computes
    ``sum((pred - label) ** 2) / (pred.size(1) * max(sum(mask), min_area))``
    over all non-batch dims, then returns the mean over the batch.

    Args:
        pred: Prediction tensor with the batch on dim 0. ``pred.size(1)`` is
            used as a normaliser (presumably the channel count — TODO confirm).
        label: Target with as many elements as ``pred``; reshaped to
            ``pred``'s shape.
        mask: Tensor with the same rank as ``pred``, summed over all non-batch
            dims to give each sample's area.
        min_area: Lower bound applied to each sample's mask sum.

    Returns:
        0-d tensor with the batch-mean normalised error.
    """
    # reshape (unlike view) also accepts non-contiguous labels, e.g. permuted
    # tensors; for contiguous labels it returns the same view as before.
    label = label.reshape(pred.size())
    reduce_dims = list(range(1, label.dim()))  # every dim except the batch dim
    loss = torch.sum((pred - label) ** 2, dim=reduce_dims)
    denominator = pred.size(1) * torch.clamp_min(torch.sum(mask, dim=reduce_dims), min_area)
    return torch.mean(loss / denominator.to(loss.device))
def get_dims_with_exclusion(dim, exclude=None):
    """Return the dimension indices ``[0, ..., dim - 1]`` minus ``exclude``.

    Args:
        dim: Number of dimensions (tensor rank).
        exclude: Optional index to drop. Negative values count from the end,
            as in PyTorch (``-1`` is the last dim).

    Returns:
        A new list of ints.

    Raises:
        ValueError: If ``exclude`` is not a valid index for ``dim`` dims.
    """
    dims = list(range(dim))
    if exclude is not None:
        if -dim <= exclude < 0:
            exclude += dim  # normalise negative indices
        dims.remove(exclude)
    return dims