|
|
|
|
|
import argparse |
|
import os |
|
from typing import Dict, List, Tuple |
|
import torch |
|
from torch import Tensor, nn |
|
|
|
import detectron2.data.transforms as T |
|
from detectron2.checkpoint import DetectionCheckpointer |
|
from detectron2.config import get_cfg |
|
from detectron2.data import build_detection_test_loader, detection_utils |
|
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format |
|
from detectron2.export import TracingAdapter, dump_torchscript_IR, scripting_with_instances |
|
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model |
|
from detectron2.modeling.postprocessing import detector_postprocess |
|
from detectron2.projects.point_rend import add_pointrend_config |
|
from detectron2.structures import Boxes |
|
from detectron2.utils.env import TORCH_VERSION |
|
from detectron2.utils.file_io import PathManager |
|
from detectron2.utils.logger import setup_logger |
|
|
|
|
|
def setup_cfg(args):
    """Build and freeze the detectron2 config described by the command line.

    The default config is created first and ``DATALOADER.NUM_WORKERS`` is set to 0.
    The PointRend config keys are then registered. After that, ``args.config_file``
    and the ``args.opts`` overrides are merged in, in that order, so either may
    override NUM_WORKERS. The result is frozen before it is returned.
    """
    config = get_cfg()
    # Presumably keeps data loading in the main process (no worker forks) -- confirm.
    config.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(config)
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
|
|
|
|
|
def export_caffe2_tracing(cfg, torch_model, inputs):
    """Export ``torch_model`` with ``Caffe2Tracer`` in the format named by ``args.format``.

    Reads the module-level ``args`` (``format`` and ``output``).

    Args:
        cfg: the detectron2 config the model was built from.
        torch_model: the model to export.
        inputs: sample inputs used for tracing; also passed when drawing the caffe2 graph.

    Returns:
        The exported Caffe2 model for ``args.format == "caffe2"``. Otherwise ``None``.
        For the "onnx" and "torchscript" formats the files are written to
        ``args.output``.
    """
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    fmt = args.format
    out_dir = args.output

    if fmt == "caffe2":
        caffe2_model = tracer.export_caffe2()
        caffe2_model.save_protobuf(out_dir)
        # Also render the network graph next to the saved protobufs.
        caffe2_model.save_graph(os.path.join(out_dir, "model.svg"), inputs=inputs)
        return caffe2_model

    if fmt == "onnx":
        import onnx

        onnx.save(tracer.export_onnx(), os.path.join(out_dir, "model.onnx"))
    elif fmt == "torchscript":
        ts_model = tracer.export_torchscript()
        with PathManager.open(os.path.join(out_dir, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, out_dir)
    return None
|
|
|
|
|
|
|
def export_scripting(torch_model: nn.Module) -> None:
    """Script ``torch_model`` to TorchScript and save it under ``args.output``.

    Reads the module-level ``args``. Writes ``model.ts`` into ``args.output`` and then
    calls ``dump_torchscript_IR`` on the scripted model.

    Args:
        torch_model: the detectron2 model to script. A ``GeneralizedRCNN`` is run via
            ``inference(inputs, do_postprocess=False)``; any other model is called
            directly on the inputs.

    Returns:
        Always ``None``. No Python-side evaluation wrapper is built for this export
        method, so the ``--run-eval`` path in ``__main__`` rejects it.

    Raises:
        AssertionError: if the torch version is older than 1.8, or if ``args.format``
            is not "torchscript".
    """
    assert TORCH_VERSION >= (1, 8)
    # Declared type of each Instances field the scripted model may produce. This is
    # handed to scripting_with_instances, presumably so Instances can be scripted with
    # concrete field types -- confirm against detectron2.export.
    fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    # Wrapper module whose forward returns a list of plain field dicts
    # (Instances.get_fields()) instead of Instances objects. Presumably this keeps the
    # scripted model's outputs deployable without the Instances type -- confirm.
    class ScriptableAdapterBase(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = torch_model
            # Put the adapter (and the wrapped model) in eval mode before scripting.
            self.eval()

    # Pick the forward at Python time, so each scripted class has exactly one code path.
    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                # Raw (not postprocessed) instances from the R-CNN inference path.
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
    with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
        torch.jit.save(ts_model, f)
    dump_torchscript_IR(ts_model, args.output)
    # TODO: Python inference for the scripted model (incl. postprocessing) is not
    # implemented, so nothing is returned for evaluation.
    return None
|
|
|
|
|
|
|
def export_tracing(torch_model, inputs):
    """Trace ``torch_model`` through ``TracingAdapter`` and save it in ``args.format``.

    Reads the module-level ``args`` (``format``, ``output``) and ``logger``.

    Args:
        torch_model: the detectron2 model to export.
        inputs: a list of input dicts. Only ``inputs[0]["image"]`` is used for tracing.

    Returns:
        A callable for Python-side evaluation, or ``None``. The callable is returned only
        for the "torchscript" format with a ``GeneralizedRCNN`` or ``RetinaNet`` model.
        It takes a list of input dicts (with "image", "height" and "width") and returns
        ``[{"instances": ...}]``.

    Raises:
        ValueError: if ``args.format`` is neither "torchscript" nor "onnx". Before this
            check existed, such a format (e.g. "caffe2") silently wrote nothing.
    """
    assert TORCH_VERSION >= (1, 8)
    if args.format not in ("torchscript", "onnx"):
        raise ValueError(
            "Tracing only supports the 'torchscript' and 'onnx' formats, "
            f"got format={args.format!r}."
        )
    image = inputs[0]["image"]
    # Drop every key except the image; the traced model takes the image tensor only.
    inputs = [{"image": image}]

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # do_postprocess=False returns raw instances. The final resize is applied
            # manually by eval_wrapper below.
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        # No custom inference function; presumably TracingAdapter then calls the model
        # directly -- confirm against detectron2.export.TracingAdapter.
        inference = None

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=11)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    # Python-side evaluation is only wired up for traced TorchScript R-CNN/RetinaNet.
    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        # Only the first input is used; assumes the evaluation loader yields batches
        # of one -- TODO confirm.
        sample = inputs[0]
        # outputs_schema turns the traced model's flat outputs back into the
        # [{"instances": Instances}] structure.
        instances = traceable_model.outputs_schema(ts_model(sample["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, sample["height"], sample["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper
|
|
|
|
|
def get_sample_inputs(args):
    """Return sample inputs used to trace/export the model.

    Reads the module-level ``cfg``.

    If ``args.sample_image`` is None, the first batch of ``cfg.DATASETS.TEST[0]`` is
    returned, as produced by ``build_detection_test_loader``.

    Otherwise the image is read in ``cfg.INPUT.FORMAT`` and resized with
    ``ResizeShortestEdge`` using the test-time size settings. The result is
    ``[{"image": CHW float32 tensor, "height": H, "width": W}]``, where H and W are the
    size of the image before resizing.
    """
    if args.sample_image is None:
        loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
        return next(iter(loader))

    img = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
    min_size = cfg.INPUT.MIN_SIZE_TEST
    resize = T.ResizeShortestEdge([min_size, min_size], cfg.INPUT.MAX_SIZE_TEST)
    orig_h, orig_w = img.shape[:2]
    resized = resize.get_transform(img).apply_image(img)
    # HWC array -> CHW float32 tensor.
    chw = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
    return [{"image": chw, "height": orig_h, "width": orig_w}]
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    # Required: every exporter writes into this directory. Without it,
    # PathManager.mkdirs below would be called with None.
    parser.add_argument(
        "--output", required=True, help="output directory for the converted model"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    # NOTE: `args`, `logger` and `cfg` are module-level globals read by the export
    # helpers in this file, so they must stay at module scope.
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)

    # Private TorchScript knob. It presumably limits re-specialization of the
    # exported graph on new input shapes, keeping --run-eval fast. Confirm that it
    # still exists in your torch version.
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # Build the model, load cfg.MODEL.WEIGHTS and switch it to eval mode.
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    sample_inputs = get_sample_inputs(args)

    # argparse `choices` guarantees one of these branches runs.
    if args.export_method == "caffe2_tracing":
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        exported_model = export_tracing(torch_model, sample_inputs)

    # Optionally evaluate the exported model on the first test dataset with COCO metrics.
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)

        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
|
|