import tensorflow as tf |
from absl import app, flags, logging |
from absl.flags import FLAGS |
from core.yolov4 import YOLO, decode, filter_boxes |
import core.utils as utils |
from core.config import cfg |
# Command-line flags that drive the export (read through FLAGS in save_tf).
flags.DEFINE_string(
    name='weights',
    default='./data/yolov4.weights',
    help='path to weights file',
)
flags.DEFINE_string(
    name='output',
    default='./checkpoints/yolov4-416',
    help='path to output',
)
flags.DEFINE_boolean(
    name='tiny',
    default=False,
    help='is yolo-tiny or not',
)
flags.DEFINE_integer(
    name='input_size',
    default=416,
    help='define input size of export model',
)
flags.DEFINE_float(
    name='score_thres',
    default=0.2,
    help='define score threshold',
)
flags.DEFINE_string(
    name='framework',
    default='tf',
    help='define what framework do you want to convert (tf, trt, tflite)',
)
flags.DEFINE_string(
    name='model',
    default='yolov4',
    help='yolov3 or yolov4',
)
def save_tf():
    """Build the YOLO inference model described by FLAGS and save it.

    Steps, in order:
      1. Load the model constants (strides, anchors, class count, xy-scale).
      2. Build the network on a square ``input_size`` x ``input_size`` x 3 input.
      3. Decode every feature map into box and probability tensors and
         concatenate them along axis 1.
      4. For ``framework == 'tflite'`` expose boxes and probabilities as two
         separate outputs; otherwise run ``filter_boxes`` with the
         ``score_thres`` flag and concatenate its two results on the last axis.
      5. Load the weights file, print the model summary and save the model
         to ``FLAGS.output``.

    Reads: FLAGS.input_size, FLAGS.model, FLAGS.tiny, FLAGS.framework,
    FLAGS.score_thres, FLAGS.weights, FLAGS.output.
    """
    # These constants are used throughout the function but are not defined
    # anywhere at module level, so they must be loaded here; without this
    # line the very next statements fail with NameError.
    # NOTE(review): relies on core.utils.load_config(FLAGS) returning the
    # tuple in exactly this order - confirm against core/utils.py.
    STRIDES, ANCHORS, NUM_CLASS, XYSCALE = utils.load_config(FLAGS)

    input_layer = tf.keras.layers.Input([FLAGS.input_size, FLAGS.input_size, 3])
    feature_maps = YOLO(input_layer, NUM_CLASS, FLAGS.model, FLAGS.tiny)

    # Divisor applied to input_size to get the grid size passed to decode()
    # for each feature map, by position. Any feature map past the end of the
    # table reuses the last divisor (32), matching the original if/else chain.
    downsample = (16, 32) if FLAGS.tiny else (8, 16, 32)
    last_idx = len(downsample) - 1

    bbox_tensors = []
    prob_tensors = []
    for i, fm in enumerate(feature_maps):
        grid_size = FLAGS.input_size // downsample[min(i, last_idx)]
        # decode() output is indexed as [0] -> boxes, [1] -> probabilities.
        output_tensors = decode(fm, grid_size, NUM_CLASS, STRIDES, ANCHORS, i, XYSCALE, FLAGS.framework)
        bbox_tensors.append(output_tensors[0])
        prob_tensors.append(output_tensors[1])

    pred_bbox = tf.concat(bbox_tensors, axis=1)
    pred_prob = tf.concat(prob_tensors, axis=1)

    if FLAGS.framework == 'tflite':
        # The tflite export keeps the raw tensors as two model outputs and
        # skips score filtering inside the graph.
        pred = (pred_bbox, pred_prob)
    else:
        boxes, pred_conf = filter_boxes(
            pred_bbox,
            pred_prob,
            score_threshold=FLAGS.score_thres,
            input_shape=tf.constant([FLAGS.input_size, FLAGS.input_size]),
        )
        pred = tf.concat([boxes, pred_conf], axis=-1)

    model = tf.keras.Model(input_layer, pred)
    utils.load_weights(model, FLAGS.weights, FLAGS.model, FLAGS.tiny)
    model.summary()
    model.save(FLAGS.output)
def main(_argv):
    """Entry point handed to ``app.run``; runs the model export.

    Args:
        _argv: Positional arguments passed in by absl; not used here.
    """
    del _argv  # Unused: all configuration comes from FLAGS.
    save_tf()
if __name__ == '__main__':
    # app.run() parses the flags and then calls main(); a SystemExit that
    # propagates out of it is caught here on purpose, so running this file
    # as a script ends without re-raising it.
    try:
        app.run(main)
    except SystemExit:
        pass