import logging
from pprint import pformat

import numpy as np
import tensorflow as tf


class BaseModel(object):
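    """Base class for graph-mode TensorFlow models.

    Builds its own tf.Graph and Session, delegates network construction to
    build_net() (a subclass responsibility), wires up the optimizer and
    summaries in build_ops(), and provides checkpoint save()/load() helpers.
    """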
    def __init__(self, hps):
        super(BaseModel, self).__init__()

        self.hps = hps
        g = tf.Graph()
        with g.as_default():
            # open a session
            config = tf.compat.v1.ConfigProto()
            config.log_device_placement = True
            config.allow_soft_placement = True
            config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=config, graph=g)
            # build model
            self.build_net()
            self.build_ops()
            # initialize
            self.sess.run(tf.compat.v1.global_variables_initializer())
            self.saver = tf.compat.v1.train.Saver()
            self.writer = tf.compat.v1.summary.FileWriter(self.hps.exp_dir + '/summary')
            # logging
            total_params = 0
            trainable_variables = tf.compat.v1.trainable_variables()
            logging.info('=' * 20)
            logging.info("Variables:")
            logging.info(pformat(trainable_variables))
            for v in trainable_variables:
                num_params = np.prod(v.get_shape().as_list())
                total_params += num_params

            logging.info("TOTAL TENSORS: %d TOTAL PARAMS: %f[M]" % (
                len(trainable_variables), total_params / 1e6))
            logging.info('=' * 20)

    def save(self, filename='params'):
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        self.saver.save(self.sess, fname)

    def load(self, filename='params'):
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        self.saver.restore(self.sess, fname)

    def build_net(self):
        # To be implemented by subclasses: build the network and define
        # self.loss (used by build_ops) and the placeholders fed in execute().
        raise NotImplementedError()

    def build_ops(self):
        # optimizer
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        learning_rate = tf.compat.v1.train.inverse_time_decay(
            self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        warmup_lr = tf.compat.v1.train.inverse_time_decay(
            0.001 * self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
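        # Warm-up: use the scaled-down rate for the first 1000 steps, then
        # switch to the regular decayed learning rate.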
        learning_rate = tf.cond(
            pred=tf.less(self.global_step, 1000),
            true_fn=lambda: warmup_lr,
            false_fn=lambda: learning_rate)
        tf.compat.v1.summary.scalar('lr', learning_rate)
        if self.hps.optimizer == 'adam':
            optimizer = tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.9, beta2=0.999, epsilon=1e-08,
                use_locking=False, name="Adam")
        elif self.hps.optimizer == 'rmsprop':
            optimizer = tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate)
        elif self.hps.optimizer == 'mom':
            optimizer = tf.compat.v1.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9)
        else:
            optimizer = tf.compat.v1.train.GradientDescentOptimizer(
                learning_rate=learning_rate)

        # regularization
        l2_reg = sum(
                [tf.reduce_sum(input_tensor=tf.square(v)) for v in tf.compat.v1.trainable_variables()
                 if ("magnitude" in v.name) or ("rescaling_scale" in v.name)])
        reg_loss = 0.00005 * l2_reg

        # train
        grads_and_vars = optimizer.compute_gradients(
            self.loss+reg_loss, tf.compat.v1.trainable_variables())
        grads, vars_ = zip(*grads_and_vars)
        if self.hps.clip_gradient > 0:
            grads, gradient_norm = tf.clip_by_global_norm(
                grads, clip_norm=self.hps.clip_gradient)
            gradient_norm = tf.debugging.check_numerics(
                gradient_norm, "Gradient norm is NaN or Inf.")
            tf.compat.v1.summary.scalar('gradient_norm', gradient_norm)
        capped_grads_and_vars = zip(grads, vars_)
        self.train_op = optimizer.apply_gradients(
            capped_grads_and_vars, global_step=self.global_step)
        
        # summary
        self.summ_op = tf.compat.v1.summary.merge_all()

    def execute(self, cmd, batch):
        # Run the requested op(s), feeding the batch into the model's input
        # placeholders (defined by the subclass in build_net()).
        return self.sess.run(
            cmd, {self.x: batch['x'], self.b: batch['b'], self.m: batch['m']})
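

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original file: _ToyModel and the
    # hyper-parameter values below are hypothetical and only illustrate the
    # subclass contract (define self.loss and the placeholders used by
    # execute()); adapt them to your own model and hps.
    import tempfile
    from types import SimpleNamespace

    class _ToyModel(BaseModel):
        def build_net(self):
            self.x = tf.compat.v1.placeholder(tf.float32, [None, 4], name='x')
            self.b = tf.compat.v1.placeholder(tf.float32, [None, 4], name='b')
            self.m = tf.compat.v1.placeholder(tf.float32, [None, 4], name='m')
            w = tf.compat.v1.get_variable('w', [4, 4])
            pred = tf.matmul(self.x * self.m, w)
            self.loss = tf.reduce_mean(tf.square(pred - self.b))

    hps = SimpleNamespace(
        exp_dir=tempfile.mkdtemp(), lr=1e-3, decay_steps=1000, decay_rate=0.1,
        optimizer='adam', clip_gradient=5.0)
    model = _ToyModel(hps)
    batch = {'x': np.ones((2, 4), np.float32),
             'b': np.zeros((2, 4), np.float32),
             'm': np.ones((2, 4), np.float32)}
    _, loss = model.execute([model.train_op, model.loss], batch)
    print('loss after one step:', loss)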