import tensorflow as tf
from absl import app, flags, logging
from absl.flags import FLAGS
import numpy as np
import cv2
from core.yolov4 import YOLOv4, YOLOv3, YOLOv3_tiny, decode
import core.utils as utils
import os
from core.config import cfg

flags.DEFINE_string('weights', './checkpoints/yolov4-416', 'path to weights file')
flags.DEFINE_string('output', './checkpoints/yolov4-416-fp32.tflite', 'path to output tflite file')
flags.DEFINE_integer('input_size', 416, 'input size of the exported model')
flags.DEFINE_string('quantize_mode', 'float32', 'quantize mode (int8, float16, float32)')
flags.DEFINE_string('dataset', "/Volumes/Elements/data/coco_dataset/coco/5k.txt", 'path to dataset')

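# Example invocation (a sketch, not part of the original script; assumes this
# file is saved as convert_tflite.py and a SavedModel already exists at
# ./checkpoints/yolov4-416):
#
#   python convert_tflite.py --weights ./checkpoints/yolov4-416 \
#       --output ./checkpoints/yolov4-416-fp16.tflite --quantize_mode float16
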
def representative_data_gen():
    # Yield a handful of preprocessed images for post-training quantization calibration.
    fimage = open(FLAGS.dataset).read().split()
    for input_value in range(10):
        if os.path.exists(fimage[input_value]):
            original_image = cv2.imread(fimage[input_value])
            original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
            # Resize/pad to the network input size and scale pixel values.
            image_data = utils.image_preprocess(np.copy(original_image), [FLAGS.input_size, FLAGS.input_size])
            img_in = image_data[np.newaxis, ...].astype(np.float32)
            print("calibration image {}".format(fimage[input_value]))
            yield [img_in]
        else:
            continue

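# Quick sanity check of the calibration generator (illustrative only; assumes
# the absl flags have already been parsed, e.g. from inside main()):
#
#   for sample in representative_data_gen():
#       print(sample[0].shape, sample[0].dtype)  # e.g. (1, 416, 416, 3) float32
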
def save_tflite():
    # Convert the SavedModel to TFLite, optionally applying post-training quantization.
    converter = tf.lite.TFLiteConverter.from_saved_model(FLAGS.weights)

    if FLAGS.quantize_mode == 'float16':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_types = [tf.float16]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
    elif FLAGS.quantize_mode == 'int8':
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        # Note: the original code first restricted supported_ops to
        # TFLITE_BUILTINS_INT8 and then immediately overwrote it; keep the
        # SELECT_TF_OPS (Flex) fallback and drive quantization via the
        # representative dataset instead.
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
        converter.allow_custom_ops = True
        converter.representative_dataset = representative_data_gen

    tflite_model = converter.convert()
    with open(FLAGS.output, 'wb') as f:
        f.write(tflite_model)

    logging.info("model saved to: {}".format(FLAGS.output))

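# Aside (not from the original script): for a fully integer-quantized model
# with int8 input/output tensors, tf.lite.TFLiteConverter also exposes
# inference_input_type / inference_output_type; a sketch of that variant:
#
#   converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
#   converter.inference_input_type = tf.int8
#   converter.inference_output_type = tf.int8
#
# This only converts cleanly when every op in the graph has an int8 kernel.
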
def demo():
    # Smoke-test the converted model: load it, feed random data, and print the raw outputs.
    interpreter = tf.lite.Interpreter(model_path=FLAGS.output)
    interpreter.allocate_tensors()
    logging.info('tflite model loaded')

    input_details = interpreter.get_input_details()
    print(input_details)
    output_details = interpreter.get_output_details()
    print(output_details)

    input_shape = input_details[0]['shape']

    input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()
    output_data = [interpreter.get_tensor(output_details[i]['index']) for i in range(len(output_details))]

    print(output_data)

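# A possible variant of the smoke test above using a real image instead of
# random data (a sketch; 'data/kite.jpg' is a hypothetical path, and the same
# interpreter/input_details objects as in demo() are assumed):
#
#   image = cv2.cvtColor(cv2.imread('data/kite.jpg'), cv2.COLOR_BGR2RGB)
#   image = utils.image_preprocess(np.copy(image), [FLAGS.input_size, FLAGS.input_size])
#   interpreter.set_tensor(input_details[0]['index'], image[np.newaxis, ...].astype(np.float32))
#   interpreter.invoke()
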
def main(_argv):
    save_tflite()
    demo()


if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        pass