File size: 5,876 Bytes
88b0dcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
""" 
@Date: 2021/07/17
@description: Base nn.Module adding checkpoint save/load and best-metric tracking.
"""
import os
import torch
import torch.nn as nn
import datetime


class BaseModule(nn.Module):
    """nn.Module base class that adds checkpoint saving/loading and metric tracking.

    ``save`` writes files named ``model_<%Y-%m-%d-%H-%M-%S>_last_<acc>_<epoch>.pkl`` and, on a
    new best accuracy, the same name with ``last`` replaced by ``best``. ``load`` restores the
    newest of these from ``ckpt_dir``.
    """

    def __init__(self, ckpt_dir=None):
        """
        :param ckpt_dir: directory holding ``.pkl`` checkpoints; created if it does not exist.
            When falsy, no directory is used and ``load`` finds no checkpoints.
        """
        super().__init__()

        self.ckpt_dir = ckpt_dir
        # Always defined so that `load` also works for an absent or freshly created directory.
        self.model_lst = []
        if ckpt_dir:
            os.makedirs(ckpt_dir, exist_ok=True)
            # Files written by `save` share the 'model_' prefix followed by a timestamp, so
            # lexicographic order is chronological for them ([-1] is the newest).
            self.model_lst = sorted(x for x in os.listdir(ckpt_dir) if x.endswith('.pkl'))

        self.last_model_path = None
        self.best_model_path = None
        self.best_accuracy = -float('inf')
        # metric name -> {'acc': best value seen, 'epoch': epoch it was reached}
        self.acc_d = {}

    def show_parameter_number(self, logger):
        """Log the total and trainable parameter counts of this module."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        logger.info('{} parameter total:{:,}, trainable:{:,}'.format(self._get_name(), total, trainable))

    def load(self, device, logger, optimizer=None, best=False):
        """Restore weights (and optionally optimizer state) from ``ckpt_dir``.

        :param device: device spec given to ``torch.device`` and used as ``map_location``.
        :param logger: logger for progress messages.
        :param optimizer: when given and ``best`` is False, its state is restored as well
            (when loading "best", a fresh optimizer may be intended, e.g. switching adam -> sgd).
        :param best: prefer the newest ``_best_`` checkpoint over the newest ``_last_`` one.
        :return: the epoch to resume from (saved epoch + 1), or 0 when there is no tagged
            checkpoint (empty folder, or untagged weights loaded as a bare state_dict).
        """
        if not self.model_lst:
            logger.info('*' * 50)
            logger.info("Empty model folder! Using initial weights")
            logger.info('*' * 50)
            return 0

        last_model_lst = [n for n in self.model_lst if '_last_' in n]
        best_model_lst = [n for n in self.model_lst if '_best_' in n]

        if not last_model_lst and not best_model_lst:
            # No last/best tag: treat the first file as a bare state_dict (weights only).
            logger.info('*' * 50)
            ckpt_path = os.path.join(self.ckpt_dir, self.model_lst[0])
            logger.info(f"Load: {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location=torch.device(device))
            self.load_state_dict(checkpoint, strict=False)
            logger.info('*' * 50)
            return 0

        map_location = torch.device(device)
        last_checkpoint = None
        best_checkpoint = None
        if last_model_lst:
            self.last_model_path = os.path.join(self.ckpt_dir, last_model_lst[-1])
            last_checkpoint = torch.load(self.last_model_path, map_location=map_location)
        if best_model_lst:
            self.best_model_path = os.path.join(self.ckpt_dir, best_model_lst[-1])
            best_checkpoint = torch.load(self.best_model_path, map_location=map_location)

        # Accuracy bookkeeping comes from the best checkpoint when present, else from the last.
        meta_checkpoint = best_checkpoint if best_checkpoint is not None else last_checkpoint
        self.best_accuracy = meta_checkpoint['accuracy']
        self._restore_acc_d(meta_checkpoint)

        if best and best_checkpoint is not None:
            checkpoint, tag, path = best_checkpoint, 'best', self.best_model_path
        elif last_checkpoint is not None:
            checkpoint, tag, path = last_checkpoint, 'last', self.last_model_path
        else:
            # Only a best checkpoint exists (e.g. no epoch has reached SAVE_FREQ yet):
            # resume from it instead of failing.
            checkpoint, tag, path = best_checkpoint, 'best', self.best_model_path

        self.load_state_dict(checkpoint['net'], strict=False)
        if optimizer is not None and not best:
            logger.info('Load optimizer')
            optimizer.load_state_dict(checkpoint['optimizer'])
            # Move optimizer state tensors to the target device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(device)

        logger.info('*' * 50)
        logger.info(f"Load {tag}: {path}")
        logger.info(f"Best accuracy: {self.best_accuracy}")
        logger.info(f"Last epoch: {checkpoint['epoch'] + 1}")
        logger.info('*' * 50)
        return checkpoint['epoch'] + 1

    def _restore_acc_d(self, checkpoint):
        """Set ``self.acc_d`` from ``checkpoint['acc_d']``.

        A missing or None ``acc_d`` becomes ``{}``. Bare float values are upgraded to
        ``{'acc': value, 'epoch': checkpoint['epoch']}``, i.e. stamped with the epoch of the
        checkpoint the metrics were read from.
        """
        saved = checkpoint.get('acc_d') or {}
        epoch = checkpoint['epoch']
        self.acc_d = {
            k: {'acc': v, 'epoch': epoch} if isinstance(v, float) else v
            for k, v in saved.items()
        }

    def update_acc(self, acc_d, epoch, logger):
        """Merge ``acc_d`` into ``self.acc_d``, keeping each metric's best value and its epoch.

        :param acc_d: metric name -> value for the current epoch (higher is treated as better).
        :param epoch: current epoch number.
        :param logger: logger; one line is written per metric.
        """
        logger.info("-" * 100)
        for k in acc_d:
            if k not in self.acc_d or acc_d[k] > self.acc_d[k]['acc']:
                self.acc_d[k] = {
                    'acc': acc_d[k],
                    'epoch': epoch
                }
            logger.info(f"Update ACC: {k} {self.acc_d[k]['acc']:.4f}({self.acc_d[k]['epoch']}-{epoch})")
        logger.info("-" * 100)

    @staticmethod
    def _remove_file(path):
        """Delete ``path`` if it is set and exists (used to replace the previous checkpoint)."""
        if path and os.path.exists(path):
            os.remove(path)

    def save(self, optim, epoch, accuracy, logger, replace=True, acc_d=None, config=None):
        """Write a "last" checkpoint periodically and a "best" checkpoint on a new best accuracy.

        :param optim: optimizer whose ``state_dict`` is stored alongside the weights.
        :param epoch: current epoch; stored in the checkpoint and embedded in the file name.
        :param accuracy: primary metric; a value above ``best_accuracy`` triggers a best save.
        :param logger: logger for progress messages.
        :param replace: delete the previously saved last/best file before writing a new one.
        :param acc_d: other evaluation metrics (e.g. visible_2/3d, full_2/3d, rmse, ...).
        :param config: optional config; ``config.TRAIN.SAVE_FREQ`` sets how often (in epochs)
            the "last" checkpoint is written. When None, it is written every call.
        :return: None
        """
        if acc_d:
            self.update_acc(acc_d, epoch, logger)
        name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S_last_{:.4f}_{}'.format(accuracy, epoch))
        name = f"model_{name}.pkl"
        # NOTE(review): this stores the raw per-epoch `acc_d`, not the best-so-far `self.acc_d`,
        # so best-per-metric history is not carried across a resume.
        checkpoint = {
            'net': self.state_dict(),
            'optimizer': optim.state_dict(),
            'epoch': epoch,
            'accuracy': accuracy,
            'acc_d': acc_d
        }

        # FIXME: config.MODEL.SAVE_LAST / config.MODEL.SAVE_BEST are not consulted yet (they were
        # short-circuited by `True or ...`), so last/best checkpoints are always written.
        save_freq = config.TRAIN.SAVE_FREQ if config is not None else 1
        if epoch % save_freq == 0:
            if replace:
                self._remove_file(self.last_model_path)
            self.last_model_path = os.path.join(self.ckpt_dir, name)
            torch.save(checkpoint, self.last_model_path)
            logger.info(f"Saved last model: {self.last_model_path}")

        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            if replace:
                self._remove_file(self.best_model_path)
            self.best_model_path = os.path.join(self.ckpt_dir, name.replace('last', 'best'))
            torch.save(checkpoint, self.best_model_path)
            logger.info("#" * 100)
            logger.info(f"Saved best model: {self.best_model_path}")
            logger.info("#" * 100)