File size: 2,811 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict

from .abs_model import abs_model
from .blocks import *
from .Loss.Loss import avg_norm_loss

class Template(abs_model):
    """ Standard Unet Implementation
        src: https://arxiv.org/pdf/1505.04597.pdf

        Thin training/inference wrapper around a Unet: builds the network and
        optimizer from an option dict, runs one supervised step, and exposes
        visualizations plus model/optimizer getters and setters.
    """
    def __init__(self, opt):
        # All architecture choices come from the option dict; we never
        # hard-code them so configs fully describe the experiment.
        resunet      = opt['model']['resunet']
        out_act      = opt['model']['out_act']
        norm_type    = opt['model']['norm_type']
        in_channels  = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        self.ncols   = opt['hyper_params']['n_cols']  # images per grid row in get_visualize

        self.model = Unet(in_channels=in_channels,
                          out_channels=out_channels,
                          norm_type=norm_type,
                          out_act=out_act,
                          resunet=resunet)

        self.optimizer = get_optimizer(opt, self.model)
        # Maps name -> detached tensor batch; consumed by get_visualize().
        self.visualization = {}


    def setup_input(self, x):
        """ Identity pass-through; subclasses may preprocess the input here. """
        return x


    def forward(self, x):
        """ Forward the input through the wrapped Unet. """
        return self.model(x)


    def compute_loss(self, y, pred):
        """ Average-norm loss between target y and prediction pred. """
        return avg_norm_loss(y, pred)


    def supervise(self, input_x, y, is_training:bool)->float:
        """ Run one supervised step.

            Computes the prediction and loss; when is_training is True also
            backpropagates and applies an optimizer step. Stores detached
            copies of target and prediction for later visualization.

            Returns the scalar loss value.
        """
        optimizer = self.optimizer
        model = self.model

        optimizer.zero_grad()
        pred = model(input_x)
        loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        # BUGFIX: 'y' previously stored pred.detach(), so the target and
        # prediction panels in the visualization were always identical.
        self.visualization['y']    = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """ Convert stored tensors to visualization numpy arrays.

            Each entry becomes an HxWxC numpy grid of at most self.ncols
            images (fewer if the batch is smaller).
        """
        ncols          = self.ncols
        visualizations = self.visualization
        ret_vis        = OrderedDict()

        for k, v in visualizations.items():
            # assumes v is a 4-D (batch, C, H, W) tensor — required by make_grid
            n = min(ncols, v.shape[0])

            plot_v = v[:n]
            ret_vis[k] = utils.make_grid(plot_v.cpu(), nrow=ncols).numpy().transpose(1,2,0)

        return ret_vis


    def inference(self, x):
        # TODO
        pass


    def batch_inference(self, x):
        # TODO
        pass


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        """ Return the trainable model(s) keyed by name. """
        return {'model': self.model}


    def get_optimizers(self) -> dict:
        """ Return the optimizer(s) keyed by name. """
        return {'optimizer': self.optimizer}


    def set_models(self, models: dict) :
        """ Replace the wrapped model; raises ValueError if 'model' missing. """
        # input test
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        """ Replace the wrapped optimizer from a dict keyed 'optimizer'. """
        self.optimizer = optimizer['optimizer']


    ####################
    # Personal Methods #
    ####################