Spaces:
Runtime error
Runtime error
File size: 2,811 Bytes
34fb220 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
from .abs_model import abs_model
from .blocks import *
from .Loss.Loss import avg_norm_loss
class Template(abs_model):
    """ Standard Unet Implementation
        src: https://arxiv.org/pdf/1505.04597.pdf

    Bundles a ``Unet`` network with its optimizer and keeps the tensors of
    the most recent ``supervise`` call around for visualization.
    """

    def __init__(self, opt):
        """Build the network and its optimizer from the option mapping.

        Args:
            opt: nested mapping. Reads ``opt['model']`` keys ``'resunet'``,
                ``'out_act'``, ``'norm_type'``, ``'in_channels'`` and
                ``'out_channels'``, and ``opt['hyper_params']['n_cols']``.
                The whole mapping is also handed to ``get_optimizer``.

        Raises:
            KeyError: if any of the keys listed above is missing.
        """
        # NOTE(review): the base-class initializer is not invoked here (same
        # as before) -- confirm abs_model needs no initialization of its own.
        resunet = opt['model']['resunet']
        out_act = opt['model']['out_act']
        norm_type = opt['model']['norm_type']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']

        # Number of images placed on one row of the visualization grid.
        self.ncols = opt['hyper_params']['n_cols']

        self.model = Unet(in_channels=in_channels,
                          out_channels=out_channels,
                          norm_type=norm_type,
                          out_act=out_act,
                          resunet=resunet)

        self.optimizer = get_optimizer(opt, self.model)

        # Filled by supervise(); maps a display name to a detached tensor.
        self.visualization = {}

    def setup_input(self, x):
        """Return the input unchanged (no pre-processing is applied)."""
        return x

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its output."""
        return self.model(x)

    def compute_loss(self, y, pred):
        """Return ``avg_norm_loss(y, pred)`` for target ``y`` and output ``pred``."""
        return avg_norm_loss(y, pred)

    def supervise(self, input_x, y, is_training: bool) -> float:
        """Run one supervised step and return the loss as a Python float.

        The forward pass and the loss are always computed. Gradients are
        cleared, back-propagated and applied only when ``is_training`` is
        true; otherwise the pass runs with autograd disabled so no graph is
        built during evaluation.

        Args:
            input_x: network input, passed straight to the model.
            y: ground-truth target handed to ``compute_loss``
                (assumed to be a tensor -- it is detached for visualization).
            is_training: whether to update the model parameters.

        Returns:
            The scalar loss value (``loss.item()``).
        """
        optimizer = self.optimizer
        model = self.model
        training = bool(is_training)

        if training:
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            pred = model(input_x)
            loss = self.compute_loss(y, pred)

        if training:
            loss.backward()
            optimizer.step()

        # Store the ground truth under 'y' (previously the prediction was
        # stored under both keys, making the two visualizations identical).
        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()

    def get_visualize(self) -> OrderedDict:
        """ Convert to visualization numpy array

        Returns:
            An ``OrderedDict`` with one entry per stored visualization
            tensor: the first ``min(self.ncols, batch)`` items of the batch
            are tiled with ``torchvision.utils.make_grid`` and the result is
            converted to a numpy array with its first axis moved last.
        """
        per_row = self.ncols
        ret_vis = OrderedDict()

        for name, tensor in self.visualization.items():
            batch = tensor.shape[0]
            shown = tensor[:min(per_row, batch)]
            grid = utils.make_grid(shown.cpu(), nrow=per_row)
            # Reorder axes (0, 1, 2) -> (1, 2, 0) for image display.
            ret_vis[name] = grid.numpy().transpose(1, 2, 0)

        return ret_vis

    def inference(self, x):
        """Not implemented yet; currently returns ``None``."""
        # TODO
        pass

    def batch_inference(self, x):
        """Not implemented yet; currently returns ``None``."""
        # TODO
        pass

    # Getter & Setter

    def get_models(self) -> dict:
        """Return the wrapped network as ``{'model': network}``."""
        return {'model': self.model}

    def get_optimizers(self) -> dict:
        """Return the optimizer as ``{'optimizer': optimizer}``."""
        return {'optimizer': self.optimizer}

    def set_models(self, models: dict):
        """Replace the wrapped network with ``models['model']``.

        Raises:
            ValueError: if ``models`` has no ``'model'`` entry.
        """
        # input test
        if 'model' not in models:
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']

    def set_optimizers(self, optimizer: dict):
        """Replace the optimizer with ``optimizer['optimizer']``.

        Raises:
            KeyError: if ``optimizer`` has no ``'optimizer'`` entry.
        """
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
|