papasega's picture
Update modelutil.py
b79b2d0 verified
raw
history blame
1.31 kB
# import tensorflow as tf
# def create_model():
# LAYERS = [tf.keras.layers.Flatten(input_shape=[28,28], name="inputlayer"),
# tf.keras.layers.Dense(300, activation='relu', name="hiddenlayer1"),
# tf.keras.layers.Dense(100, activation='relu', name="hiddenlayer2"),
# tf.keras.layers.Dense(10, activation='softmax', name="outputlayer")]
# model = tf.keras.models.Sequential(LAYERS)
# model.load_weights('./checkpoint')
# # LOSS_FUNCTION = tf.keras.losses.SparseCategoricalCrossentropy() # HERE
# # OPTIMIZER = tf.keras.optimizers.legacy.Adam()
# # METRICS = ["accuracy"]
# # model.compile(loss=LOSS_FUNCTION,
# # optimizer=OPTIMIZER,
# # metrics=METRICS)
# return model
import tensorflow as tf
def create_model():
    """Build the 28x28 -> 10-class dense classifier, without weights or compile.

    The network is a ``Sequential`` stack of:

    * ``inputlayer``   - ``Flatten`` taking ``[28, 28]`` inputs
    * ``hiddenlayer1`` - ``Dense(300)`` with ReLU activation
    * ``hiddenlayer2`` - ``Dense(100)`` with ReLU activation
    * ``outputlayer``  - ``Dense(10)`` with softmax activation

    The model is returned as constructed: this function neither compiles it
    nor loads any weights into it.

    Returns:
        The freshly built ``tf.keras.models.Sequential`` instance.
    """
    # NOTE(review): the explicit layer names presumably have to match the
    # names stored in previously saved checkpoints - confirm before renaming.
    layers = [
        tf.keras.layers.Flatten(input_shape=[28, 28], name="inputlayer"),
        tf.keras.layers.Dense(300, activation='relu', name="hiddenlayer1"),
        tf.keras.layers.Dense(100, activation='relu', name="hiddenlayer2"),
        tf.keras.layers.Dense(10, activation='softmax', name="outputlayer"),
    ]
    return tf.keras.models.Sequential(layers)
def load_model_weights(model, checkpoint_path) -> None:
    """Load saved weights into *model* in place.

    Thin wrapper that forwards *checkpoint_path* unchanged to
    ``model.load_weights``. Whatever that call returns is discarded, and any
    exception it raises propagates to the caller.

    Args:
        model: Object exposing a ``load_weights(path)`` method.
        checkpoint_path: Location of the saved weights, passed through as-is.
    """
    model.load_weights(checkpoint_path)