File size: 3,868 Bytes
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e7060e
 
34fb220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import utils
from collections import OrderedDict
import numpy as np

from .abs_model import abs_model
from .Loss.Loss import norm_loss
from .blocks import *
from .SSN_Model import SSN_Model


class SSN(abs_model):
    def __init__(self, opt):
        """Build the SSN network, its optimizer and the training loss.

        Args:
            opt: nested configuration dict. Reads ``opt['model']`` keys
                ``mid_act``, ``out_act``, ``in_channels`` and ``out_channels``,
                plus ``opt['hyper_params']['n_cols']``. The whole ``opt`` is
                also forwarded to ``get_optimizer``.
        """
        model_cfg = opt['model']
        mid_act = model_cfg['mid_act']
        out_act = model_cfg['out_act']
        in_channels = model_cfg['in_channels']
        out_channels = model_cfg['out_channels']
        # Caps how many samples get_visualize draws per grid (and its row width).
        self.ncols = opt['hyper_params']['n_cols']

        self.model = SSN_Model(
            in_channels=in_channels,
            out_channels=out_channels,
            mid_act=mid_act,
            out_act=out_act,
        )
        self.optimizer = get_optimizer(opt, self.model)
        # Populated by supervise() with detached tensors for get_visualize().
        self.visualization = {}

        # norm_loss configured with norm=1 (presumably an L1 loss — confirm in Loss.Loss).
        self.norm_loss_ = norm_loss(norm=1)

    def setup_input(self, x):
        """Return the input batch unchanged; SSN applies no preprocessing here."""
        return x


    def forward(self, x):
        """Run the network on a batch dictionary.

        Args:
            x: mapping that must contain the entries ``'mask'`` and ``'ibl'``.

        Returns:
            The result of ``self.model(mask, ibl)``.
        """
        for required in ('mask', 'ibl'):
            assert required in x.keys(), f'{required} not in input'

        return self.model(x['mask'], x['ibl'])


    def compute_loss(self, y, pred):
        """Evaluate ``self.norm_loss_`` (built with ``norm=1``) on target ``y`` and ``pred``."""
        return self.norm_loss_.loss(y, pred)


    def supervise(self, input_x, y, is_training: bool) -> float:
        """Run one forward pass, optionally take an optimizer step, and return the loss.

        Args:
            input_x: batch dict with ``'mask'`` and ``'ibl'`` entries (see ``forward``).
            y: target tensor compared against the prediction.
            is_training: when True, back-propagate and step the optimizer;
                when False, only evaluate the loss.

        Returns:
            The loss as a Python float.

        Side effects:
            Stores detached ``mask``, ``ibl``, ``y`` and ``pred`` tensors in
            ``self.visualization`` for ``get_visualize``.
        """
        optimizer = self.optimizer
        optimizer.zero_grad()

        # Only record the autograd graph when we are going to back-propagate;
        # evaluation passes would otherwise build a graph and retain
        # activations for nothing.
        with torch.set_grad_enabled(is_training):
            pred = self.forward(input_x)
            loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        self.visualization['mask'] = input_x['mask'].detach()
        self.visualization['ibl'] = input_x['ibl'].detach()
        self.visualization['y'] = y.detach()
        self.visualization['pred'] = pred.detach()

        return loss.item()


    def get_visualize(self) -> OrderedDict:
        """Convert the tensors stored by ``supervise`` into displayable image grids.

        Each entry of ``self.visualization`` is truncated to its first
        ``self.ncols`` samples, min-max normalized to [0, 1], tiled with
        ``torchvision.utils.make_grid`` (``self.ncols`` images per row), and
        returned as an ``(H, W, C)`` numpy array clipped to [0, 1].

        Returns:
            OrderedDict mapping each visualization key to its grid image.
        """
        n_cols = self.ncols
        ret_vis = OrderedDict()

        for key, value in self.visualization.items():
            # Cast to float so normalization also works for bool/int tensors.
            plot_v = value[:min(n_cols, value.shape[0])].float()

            # Min-max normalize. A constant tensor (e.g. an all-zero mask) has
            # zero range; dividing by it would fill the grid with NaNs, so in
            # that case only the offset is removed and the image is black.
            v_min = plot_v.min()
            v_range = plot_v.max() - v_min
            plot_v = plot_v - v_min
            if v_range > 0:
                plot_v = plot_v / v_range

            grid = utils.make_grid(plot_v.cpu(), nrow=n_cols)
            # make_grid returns (C, H, W); image viewers expect (H, W, C).
            ret_vis[key] = np.clip(grid.numpy().transpose(1, 2, 0), 0.0, 1.0)

        return ret_vis


    def get_logs(self):
        """Return extra training logs; SSN records none, so this always yields ``None``."""
        return None


    def inference(self, x):
        """Predict a single 2D output for one (mask, ibl) pair.

        Args:
            x: dict with 2D arrays ``'mask'`` and ``'ibl'`` (anything
                ``torch.tensor`` accepts, e.g. numpy arrays).

        Returns:
            2D numpy array: channel 0 of the first prediction, divided by 30.0
            and clipped to [0, 1].
        """
        keys = ['mask', 'ibl']
        for k in keys:
            assert k in x.keys(), '{} not in input'.format(k)
            assert len(x[k].shape) == 2, '{} should be 2D tensor'.format(k)

        # Put the inputs on whatever device the model's weights live on, so a
        # GPU-resident model does not fail with a device mismatch. A model
        # without parameters falls back to CPU (the previous hard-coded choice).
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        mask = torch.tensor(x['mask'])[None, None, ...].float().to(device)
        ibl = torch.tensor(x['ibl'])[None, None, ...].float().to(device)

        input_x = {'mask': mask, 'ibl': ibl}
        # Inference never back-propagates, so skip building the autograd graph.
        with torch.no_grad():
            pred = self.forward(input_x)

        # NOTE(review): 30.0 looks like the intensity scale of the raw network
        # output; confirm against the training targets.
        return np.clip(pred[0, 0].detach().cpu().numpy() / 30.0, 0.0, 1.0)



    def batch_inference(self, x):
        """Batched inference hook; not implemented yet, so it returns ``None``."""
        # TODO: implement a batched counterpart of ``inference``.
        return None


    """ Getter & Setter
    """
    def get_models(self) -> dict:
        """Return the trainable networks keyed by name (only ``'model'``)."""
        return dict(model=self.model)


    def get_optimizers(self) -> dict:
        """Return the optimizers keyed by name (only ``'optimizer'``)."""
        return dict(optimizer=self.optimizer)


    def set_models(self, models: dict):
        """Replace the network with ``models['model']``.

        Args:
            models: dict that must contain a ``'model'`` entry.

        Raises:
            ValueError: if ``models`` has no ``'model'`` entry.
        """
        # The key is looked up in the *argument*, so say so in the message.
        if 'model' not in models:
            raise ValueError("'model' not in models")

        # NOTE(review): self.optimizer is not rebuilt here; if it was created
        # from the previous model's parameters, call set_optimizers as well.
        self.model = models['model']


    def set_optimizers(self, optimizer: dict):
        """Install ``optimizer['optimizer']`` as the active optimizer.

        A missing ``'optimizer'`` key propagates as ``KeyError``.
        """
        new_optimizer = optimizer['optimizer']
        self.optimizer = new_optimizer

    ####################
    # Personal Methods #
    ####################