kevinwang676's picture
Upload 93 files
9016314 verified
raw
history blame
No virus
4.35 kB
import logging
from pprint import pformat
import numpy as np
import tensorflow as tf
class BaseModel:
    """Base class for TF1-style (graph + session) models.

    Construction creates a private ``tf.Graph`` and a session bound to it. It
    then calls ``build_net`` and ``build_ops`` inside that graph, initializes
    all global variables, and creates a ``Saver`` plus a summary
    ``FileWriter`` writing to ``<hps.exp_dir>/summary``.

    Subclasses must implement ``build_net``. This base class later reads
    ``self.loss`` (in ``build_ops``) and feeds ``self.x``, ``self.b`` and
    ``self.m`` (in ``execute``), so ``build_net`` is expected to define them.

    Hyperparameters read from ``hps`` by this class: ``exp_dir``, ``lr``,
    ``decay_steps``, ``decay_rate``, ``optimizer``, ``clip_gradient``.
    """

    # Number of initial global steps during which the warm-up learning rate is
    # used instead of the regular decayed learning rate.
    WARMUP_STEPS = 1000
    # Warm-up learning rate expressed as a fraction of ``hps.lr`` (it is a
    # constant fraction, not a ramp; it is also inverse-time decayed).
    WARMUP_LR_SCALE = 0.001
    # Weight of the L2 penalty applied to trainable variables whose names
    # contain "magnitude" or "rescaling_scale".
    REG_WEIGHT = 0.00005

    def __init__(self, hps):
        """Build the model graph, open a session and initialize variables.

        Args:
            hps: hyperparameter object accessed by attribute (see class doc).
        """
        super().__init__()
        self.hps = hps
        g = tf.Graph()
        with g.as_default():
            # open a session bound to this model's private graph
            config = tf.compat.v1.ConfigProto()
            config.log_device_placement = True
            config.allow_soft_placement = True
            config.gpu_options.allow_growth = True
            self.sess = tf.compat.v1.Session(config=config, graph=g)
            # build model (both calls add ops to `g`, the default graph here)
            self.build_net()
            self.build_ops()
            # initialize
            self.sess.run(tf.compat.v1.global_variables_initializer())
            self.saver = tf.compat.v1.train.Saver()
            self.writer = tf.compat.v1.summary.FileWriter(self.hps.exp_dir + '/summary')
            # logging: must stay inside the graph context so that
            # trainable_variables() reads `g` rather than the global graph.
            trainable_variables = tf.compat.v1.trainable_variables()
            # num_elements() is None for a partially-unknown shape; count it
            # as 0 instead of crashing on np.prod([None, ...]).
            total_params = sum(
                v.get_shape().num_elements() or 0 for v in trainable_variables)
            logging.info('=' * 20)
            logging.info("Variables:")
            logging.info(pformat(trainable_variables))
            logging.info("TOTAL TENSORS: %d TOTAL PARAMS: %f[M]",
                         len(trainable_variables), total_params / 1e6)
            logging.info('=' * 20)

    def save(self, filename='params'):
        """Save all variables to ``<exp_dir>/weights/<filename>.ckpt``.

        The ``weights`` directory is created first when missing, because
        ``Saver.save`` raises ``ValueError`` if the parent directory of the
        checkpoint path does not exist.
        """
        weights_dir = f'{self.hps.exp_dir}/weights'
        tf.io.gfile.makedirs(weights_dir)
        self.saver.save(self.sess, f'{weights_dir}/{filename}.ckpt')

    def load(self, filename='params'):
        """Restore variables from ``<exp_dir>/weights/<filename>.ckpt``."""
        fname = f'{self.hps.exp_dir}/weights/{filename}.ckpt'
        self.saver.restore(self.sess, fname)

    def build_net(self):
        """Build the network; must be implemented by subclasses.

        Expected to define at least ``self.loss`` (used by ``build_ops``) and
        the placeholders ``self.x``, ``self.b``, ``self.m`` (fed by
        ``execute``).
        """
        raise NotImplementedError()

    def _make_optimizer(self, learning_rate):
        """Return the tf.compat.v1 optimizer selected by ``hps.optimizer``.

        'adam', 'rmsprop' and 'mom' select Adam, RMSProp and momentum (0.9)
        respectively. Any other value falls back to plain gradient descent.
        A warning is logged unless the value is 'sgd'.
        """
        name = self.hps.optimizer
        if name == 'adam':
            return tf.compat.v1.train.AdamOptimizer(
                learning_rate=learning_rate,
                beta1=0.9, beta2=0.999, epsilon=1e-08,
                use_locking=False, name="Adam")
        if name == 'rmsprop':
            return tf.compat.v1.train.RMSPropOptimizer(
                learning_rate=learning_rate)
        if name == 'mom':
            return tf.compat.v1.train.MomentumOptimizer(
                learning_rate=learning_rate,
                momentum=0.9)
        if name != 'sgd':
            # Keep the historical fallback, but make typos visible.
            logging.warning(
                "Unrecognized optimizer %r; falling back to "
                "GradientDescentOptimizer.", name)
        return tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate=learning_rate)

    def build_ops(self):
        """Create the learning-rate schedule, train op and summary op.

        Runs inside the model graph after ``build_net``. Sets
        ``self.global_step``, ``self.train_op`` and ``self.summ_op``.
        """
        # optimizer: inverse-time-decayed lr. For the first WARMUP_STEPS
        # steps, a (likewise decayed) WARMUP_LR_SCALE * lr is used instead.
        self.global_step = tf.compat.v1.train.get_or_create_global_step()
        decayed_lr = tf.compat.v1.train.inverse_time_decay(
            self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        warmup_lr = tf.compat.v1.train.inverse_time_decay(
            self.WARMUP_LR_SCALE * self.hps.lr, self.global_step,
            self.hps.decay_steps, self.hps.decay_rate,
            staircase=True)
        # Separate names so neither branch lambda closes over a variable that
        # this statement rebinds.
        learning_rate = tf.cond(
            pred=tf.less(self.global_step, self.WARMUP_STEPS),
            true_fn=lambda: warmup_lr,
            false_fn=lambda: decayed_lr)
        tf.compat.v1.summary.scalar('lr', learning_rate)
        optimizer = self._make_optimizer(learning_rate)

        # regularization: L2 only on "magnitude" / "rescaling_scale" variables.
        # sum() over an empty selection yields 0, i.e. no penalty.
        reg_vars = [v for v in tf.compat.v1.trainable_variables()
                    if ("magnitude" in v.name) or ("rescaling_scale" in v.name)]
        l2_reg = sum(tf.reduce_sum(input_tensor=tf.square(v)) for v in reg_vars)
        reg_loss = self.REG_WEIGHT * l2_reg

        # train
        grads_and_vars = optimizer.compute_gradients(
            self.loss + reg_loss, tf.compat.v1.trainable_variables())
        grads, vars_ = zip(*grads_and_vars)
        if self.hps.clip_gradient > 0:
            grads, gradient_norm = tf.clip_by_global_norm(
                grads, clip_norm=self.hps.clip_gradient)
            # NOTE(review): train_op does not depend on this check, so it only
            # fires when the summary op (which reads it) is evaluated.
            gradient_norm = tf.debugging.check_numerics(
                gradient_norm, "Gradient norm is NaN or Inf.")
            tf.compat.v1.summary.scalar('gradient_norm', gradient_norm)
        self.train_op = optimizer.apply_gradients(
            zip(grads, vars_), global_step=self.global_step)

        # summary
        self.summ_op = tf.compat.v1.summary.merge_all()

    def execute(self, cmd, batch):
        """Run ``cmd`` in the model session, feeding one batch.

        ``batch`` must provide keys 'x', 'b' and 'm', fed to the placeholders
        ``self.x``, ``self.b`` and ``self.m`` that ``build_net`` defines.
        """
        return self.sess.run(cmd, {self.x: batch['x'],
                                   self.b: batch['b'],
                                   self.m: batch['m']})