import torch
import torch.nn as nn

DEV = torch.device('cuda:0')


def find_layers(module, layers=(nn.Conv2d, nn.Linear), name=''):
    """Recursively collect sub-modules whose exact type is listed in ``layers``.

    Args:
        module: Root ``nn.Module`` to search.
        layers: Collection of module classes to look for. Matching uses the
            exact type (``type(module) in layers``), so subclasses of a
            listed class are not matched.
        name: Dotted-name prefix for ``module``; ``''`` for the root.

    Returns:
        dict mapping the dotted path of every matching module (relative to
        the root, built from ``named_children`` names) to the module itself.
        If the root itself matches, the result is ``{name: module}``.
    """
    # Exact type comparison on purpose: isinstance() would also match
    # subclasses, which the original lookup does not.
    if type(module) in layers:
        return {name: module}

    found = {}
    for child_name, child in module.named_children():
        child_path = f'{name}.{child_name}' if name != '' else child_name
        found.update(find_layers(child, layers=layers, name=child_path))
    return found


def _halves_to_32(groupsize):
    """Return True if repeatedly halving ``groupsize`` lands exactly on 32."""
    while groupsize > 32:
        halved = groupsize / 2
        if halved == groupsize:
            # Only infinity is unchanged by halving; it can never reach 32.
            return False
        groupsize = halved
    return groupsize == 32


def gen_conditions(_wbits, _groupsize):
    """Enumerate the (wbits, groupsize) settings stepped through after the start.

    Starting from ``(_wbits, _groupsize)``, each step halves the group size
    while it is above 32; otherwise it doubles the bit width and resets the
    group size to ``_groupsize``. The walk stops once the bit width is at
    least 8 and the group size is -1 or 32. The starting pair itself is not
    part of the result.

    Args:
        _wbits: Starting bit width; must be positive.
        _groupsize: Starting group size; must be -1 or 32 times a power of
            two (32, 64, 128, ...).

    Returns:
        list of ``(int, int)`` tuples, in the order they are visited.

    Raises:
        ValueError: If the arguments would make the walk never terminate
            (non-positive ``_wbits``, or a ``_groupsize`` that never halves
            to exactly 32).
    """
    # Doubling a non-positive bit width never reaches 8 -> endless loop.
    if not _wbits > 0:
        raise ValueError(f'wbits must be positive, got {_wbits!r}')
    # A group size that skips over 32 while halving never meets the stop
    # condition -> endless loop with an ever-growing result list.
    if _groupsize != -1 and not _halves_to_32(_groupsize):
        raise ValueError(
            f'groupsize must be -1 or 32 * 2**k (k >= 0), got {_groupsize!r}')

    wbits, groupsize = _wbits, _groupsize
    conditions = []
    while not (wbits >= 8 and groupsize in (-1, 32)):
        if groupsize > 32:
            groupsize /= 2
        else:
            wbits *= 2
            groupsize = _groupsize
        conditions.append((int(wbits), int(groupsize)))
    return conditions


# Copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
    """Compute the noise-to-signal power ratio between ``y_pred`` and ``y_real``.

    Both tensors are cast to float32 and flattened to ``(dim0, -1)``
    (0-D and 1-D inputs are treated as a single row). For every row::

        snr = sum((y_pred - y_real) ** 2) / (sum(y_real ** 2) + 1e-7)

    so 0 means the tensors are identical and larger values mean more error.

    Args:
        y_pred: Tensor under test.
        y_real: Reference tensor; must have the same shape as ``y_pred``.
        reduction: How to combine the per-row values (case-insensitive):
            ``'mean'``, ``'sum'`` or ``'none'`` (return the per-row tensor).
            Defaults to ``'mean'``.

    Raises:
        ValueError: If the shapes differ, or ``reduction`` is not one of the
            supported values.

    Returns:
        A scalar tensor for ``'mean'`` / ``'sum'``, or a 1-D tensor with one
        entry per row for ``'none'``.
    """
    y_pred = y_pred.to(torch.float32)
    y_real = y_real.to(torch.float32)

    if y_pred.shape != y_real.shape:
        raise ValueError(f'Can not compute snr loss for tensors with different shape. '
                         f'({y_pred.shape} and {y_real.shape})')
    reduction = str(reduction).lower()

    if y_pred.ndim < 2:
        # Scalars and vectors form one row; flatten(start_dim=1) needs 2 dims.
        y_pred = y_pred.reshape(1, -1)
        y_real = y_real.reshape(1, -1)
    else:
        y_pred = y_pred.flatten(start_dim=1)
        y_real = y_real.flatten(start_dim=1)

    noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
    signal_power = torch.pow(y_real, 2).sum(dim=-1)
    # The small epsilon keeps the ratio finite for an all-zero reference row.
    snr = noise_power / (signal_power + 1e-7)

    if reduction == 'mean':
        return torch.mean(snr)
    if reduction == 'sum':
        return torch.sum(snr)
    if reduction == 'none':
        return snr
    raise ValueError('Unsupported reduction method.')