# yichen-purdue's picture
# init
# 34fb220
# raw
# history blame
# No virus
# 5.31 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from torchvision.transforms import Resize
from collections import OrderedDict
import numpy as np
import matplotlib.cm as cm
import matplotlib as mpl
from torchvision.transforms import InterpolationMode
from .abs_model import abs_model
from .blocks import *
from .SSN import SSN
from .SSN_v1 import SSN_v1
from .Loss.Loss import norm_loss, grad_loss
from .Attention_Unet import Attention_Unet
class Sparse_PH(abs_model):
    """Training / inference wrapper around a configurable backbone network.

    The wrapper owns the network, its optimizer, the loss helpers and a small
    cache of tensors from the most recent ``supervise`` call that can be
    rendered with ``get_visualize``.
    """

    def __init__(self, opt):
        """Build the backbone, optimizer and loss helpers from ``opt``.

        Args:
            opt: Nested configuration mapping. This constructor reads
                ``opt['model']`` (``mid_act``, ``out_act``, ``in_channels``,
                ``out_channels``, ``resnet``, ``backbone``, ``focal``,
                ``clip``, ``norm_loss``, ``grad_loss``, ``ggrad_loss``,
                ``lap_loss``), ``opt['hyper_params']['n_cols']`` and
                ``opt['dataset']`` (``linear_scale``, ``linear_offset``).
                ``opt`` is also forwarded unchanged to ``get_optimizer``.

        Raises:
            KeyError: If one of the configuration keys listed above is missing.
            ValueError: If ``opt['model']['backbone']`` is neither
                ``'Default'`` nor ``'ATTN'``.
        """
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        backbone = opt['model']['backbone']

        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']
        self.clip = opt['model']['clip']

        # Scalar weights applied to the individual loss terms in compute_loss.
        self.norm_loss_ = opt['model']['norm_loss']
        self.grad_loss_ = opt['model']['grad_loss']
        self.ggrad_loss_ = opt['model']['ggrad_loss']
        self.lap_loss = opt['model']['lap_loss']

        # Upper bound used when predictions are clipped in supervise().
        self.clip_range = opt['dataset']['linear_scale'] + opt['dataset']['linear_offset']

        if backbone == 'Default':
            self.model = SSN_v1(in_channels=in_channels,
                                out_channels=out_channels,
                                mid_act=mid_act,
                                out_act=out_act,
                                resnet=resnet)
        elif backbone == 'ATTN':
            self.model = Attention_Unet(in_channels, out_channels, mid_act=mid_act, out_act=out_act)
        else:
            # Without this guard self.model would stay unset and the failure
            # would surface below as an unrelated AttributeError.
            raise ValueError(
                "Unsupported backbone {!r}; expected 'Default' or 'ATTN'".format(backbone))

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}
        # Initialised here so get_logs() is safe before the first supervise().
        self.loss_log = {}

        self.norm_loss = norm_loss()
        self.grad_loss = grad_loss()

    def setup_input(self, x):
        """Return ``x`` unchanged; no input preprocessing is performed here."""
        return x

    def forward(self, x):
        """Run the backbone on ``x`` and return its output."""
        return self.model(x)

    def compute_loss(self, y, pred):
        """Combine the weighted loss terms and record them in ``self.loss_log``.

        Args:
            y: Target tensor passed to the loss helpers.
            pred: Prediction tensor passed to the loss helpers.

        Returns:
            The sum of the four weighted terms; raised to the third power
            when ``self.focal`` is truthy.
        """
        nloss = self.norm_loss.loss(y, pred) * self.norm_loss_
        gloss = self.grad_loss.loss(pred) * self.grad_loss_
        ggloss = self.grad_loss.gloss(y, pred) * self.ggrad_loss_
        laploss = self.grad_loss.laploss(pred) * self.lap_loss

        total_loss = nloss + gloss + ggloss + laploss

        # The log holds the individual weighted terms (before any focal power).
        self.loss_log = {
            'norm_loss': nloss.item(),
            'grad_loss': gloss.item(),
            'grad_l1_loss': ggloss.item(),
            'lap_loss': laploss.item(),
        }

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss

    def supervise(self, input_x, y, is_training: bool) -> float:
        """Run one supervised step and cache tensors for visualization.

        Args:
            input_x: Mapping whose ``'x'`` entry is the network input.
            y: Target tensor; channels 0 and 1 are cached as ``y_fore`` and
                ``y_back``.
            is_training: When true, back-propagate the loss and step the
                optimizer; otherwise only evaluate the loss.

        Returns:
            The loss value as a Python float.
        """
        optimizer = self.optimizer
        x = input_x['x']

        optimizer.zero_grad()
        pred = self.forward(x)
        if self.clip:
            pred = torch.clip(pred, 0.0, self.clip_range)

        # NOTE(review): gradients are tracked even when is_training is False;
        # presumably the caller wraps evaluation in torch.no_grad() - confirm.
        loss = self.compute_loss(y, pred)
        if is_training:
            loss.backward()
            optimizer.step()

        # Cache every input channel separately, keeping the channel axis.
        for i in range(x.shape[1]):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y_fore'] = y[:, 0:1].detach()
        self.visualization['y_back'] = y[:, 1:2].detach()
        self.visualization['pred_fore'] = pred[:, 0:1].detach()
        self.visualization['pred_back'] = pred[:, 1:2].detach()

        return loss.item()

    def get_visualize(self) -> OrderedDict:
        """Convert the cached tensors to colour-mapped numpy images.

        For every cached tensor, at most ``self.ncols`` leading samples are
        tiled with ``torchvision.utils.make_grid``, clipped to ``[0, 1]`` and
        passed through ``plasma``.

        Returns:
            An ``OrderedDict`` mapping each visualization key to the array
            returned by ``plasma``.
        """
        nrows = self.ncols
        ret_vis = OrderedDict()

        for k, v in self.visualization.items():
            n = min(nrows, v.shape[0])
            grid = utils.make_grid(v[:n].cpu(), nrow=nrows)
            # make_grid output is transposed with (1, 2, 0) so the first axis
            # (channels) ends up last, as plasma() indexes [:, :, 0].
            image = np.clip(grid.numpy().transpose(1, 2, 0), 0.0, 1.0)
            ret_vis[k] = self.plasma(image)

        return ret_vis

    def get_logs(self):
        """Return the loss terms logged by the latest ``compute_loss`` call.

        The result is an empty dict until a loss has been computed.
        """
        return self.loss_log

    def inference(self, x):
        """Run the network on a single numpy input.

        Args:
            x: Mapping with ``'x'`` (a 3-D numpy array whose last axis is
                moved to the front before batching - presumably H x W x C)
                and ``'device'`` (the device the input tensor is moved to).

        Returns:
            The first prediction of the batch as a numpy array with its
            leading axis moved to the end.
        """
        x, device = x['x'], x['device']
        x = torch.from_numpy(x.transpose((2, 0, 1))).unsqueeze(dim=0).float().to(device)

        # The result is detached immediately, so no autograd graph is needed.
        with torch.no_grad():
            pred = self.forward(x)

        pred = pred[0].detach().cpu().numpy().transpose((1, 2, 0))
        return pred

    def batch_inference(self, x):
        """Run the network on ``x['x']`` and return the raw prediction."""
        x = x['x']
        pred = self.forward(x)
        return pred

    # ----------------- #
    # Getter & Setter   #
    # ----------------- #
    def get_models(self) -> dict:
        """Return the wrapped network under the key ``'model'``."""
        return {'model': self.model}

    def get_optimizers(self) -> dict:
        """Return the optimizer under the key ``'optimizer'``."""
        return {'optimizer': self.optimizer}

    def set_models(self, models: dict):
        """Replace the wrapped network with ``models['model']``.

        Raises:
            ValueError: If ``models`` has no ``'model'`` entry.
        """
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']

    def set_optimizers(self, optimizer: dict):
        """Replace the optimizer with ``optimizer['optimizer']``."""
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        """Colour-map channel 0 of ``x`` with matplotlib's 'plasma' colormap.

        Args:
            x: Array indexed as ``x[:, :, 0]``; values are normalised with
                ``vmin=0.0`` and ``vmax=1``.

        Returns:
            The RGBA mapping of that channel with the alpha component dropped.
        """
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:, :, 0])[:, :, :3]
        return bimg