|
import torch |
|
import torch.nn as nn |
|
|
|
# torch device object for CUDA device index 0.
DEV = torch.device('cuda:0')
|
|
|
|
|
def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Recursively collect the sub-modules of ``module`` whose type is in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Collection of module classes to look for. Matching is done on
            the exact type (``type(m) in layers``), so subclasses of the
            listed classes are not selected.
        name: Dotted path prefix of ``module``; ``''`` for the root.

    Returns:
        dict mapping the dotted path of each matching module (as built from
        ``named_children()`` names) to the module object. If ``module`` itself
        matches, the result is ``{name: module}`` and its children are not
        visited.
    """
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        # Build "parent.child" paths, without a leading dot at the root.
        child_path = f'{name}.{child_name}' if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found
|
|
|
|
|
def gen_conditions(_wbits, _groupsize):
    """Enumerate the ``(wbits, groupsize)`` pairs that follow a starting pair.

    Starting from ``(_wbits, _groupsize)`` the walk repeatedly either

    * halves ``groupsize`` while it is greater than 32, or
    * doubles ``wbits`` and resets ``groupsize`` to ``_groupsize``,

    appending the new pair after every step. It stops once ``wbits >= 8`` and
    ``groupsize`` can no longer be halved (``groupsize <= 32``, which includes
    the ``-1`` value). The starting pair itself is not part of the result.

    Args:
        _wbits: Starting bit width; must be positive.
        _groupsize: Starting group size, or ``-1``.

    Returns:
        list of ``(int, int)`` tuples in the order they are visited.

    Raises:
        ValueError: If ``_wbits`` is not positive (doubling it would never
            reach 8, so the walk could not terminate).
    """
    if _wbits <= 0:
        raise ValueError(f'_wbits must be positive, got {_wbits}.')

    wbits = _wbits
    groupsize = _groupsize
    conditions = []
    # Terminate on any groupsize <= 32 (not only exactly 32 or -1); otherwise
    # group sizes that never hit 32 exactly would loop forever.
    while wbits < 8 or groupsize > 32:
        if groupsize > 32:
            groupsize //= 2
        else:
            wbits *= 2
            groupsize = _groupsize
        conditions.append((int(wbits), int(groupsize)))
    return conditions
|
|
|
|
|
|
|
def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    """Compute the noise-to-signal power ratio between ``y_pred`` and ``y_real``.

    Both tensors are cast to float32 and flattened to two dimensions with the
    first dimension kept; 0-d and 1-d inputs are treated as a single sample.
    For every row the error is

        snr = sum((y_pred - y_real) ^ 2) / (sum(y_real ^ 2) + 1e-7)

    so 0 means the two tensors are identical and larger values mean more
    noise relative to the reference signal.

    Args:
        y_pred (torch.Tensor): Tensor to evaluate.
        y_real (torch.Tensor): Reference tensor; must have the same shape as
            ``y_pred``.
        reduction (str, optional): How to combine the per-row values:
            ``'mean'``, ``'sum'`` or ``'none'`` (case-insensitive).
            Defaults to 'mean'.

    Raises:
        ValueError: If the two tensors have different shapes.
        ValueError: If ``reduction`` is not one of the supported modes.

    Returns:
        torch.Tensor: A 0-d float32 tensor for ``'mean'`` / ``'sum'``, or a
        1-d tensor with one value per row for ``'none'``.
    """
    y_pred = y_pred.type(torch.float32)
    y_real = y_real.type(torch.float32)

    if y_pred.shape != y_real.shape:
        raise ValueError('Can not compute snr loss for tensors with different shape. '
                         f'({y_pred.shape} and {y_real.shape})')

    # Reject an unknown reduction before doing any tensor work.
    reduction = str(reduction).lower()
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError('Unsupported reduction method.')

    # 0-d and 1-d inputs become a single row; flatten(start_dim=1) would
    # otherwise fail on a 0-d tensor.
    if y_pred.ndim < 2:
        y_pred = y_pred.reshape(1, -1)
        y_real = y_real.reshape(1, -1)

    y_pred = y_pred.flatten(start_dim=1)
    y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    # The small constant keeps the ratio finite for an all-zero reference.
    snr = noise_power / (signal_power + 1e-7)

    if reduction == 'mean':
        return torch.mean(snr)
    if reduction == 'sum':
        return torch.sum(snr)
    return snr
|
|