Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# Copyright (c) Facebook, Inc. and its affiliates. | |
import argparse | |
import os | |
from typing import Dict, List, Tuple | |
import torch | |
from torch import Tensor, nn | |
import detectron2.data.transforms as T | |
from detectron2.checkpoint import DetectionCheckpointer | |
from detectron2.config import get_cfg | |
from detectron2.data import build_detection_test_loader, detection_utils | |
from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format | |
from detectron2.export import ( | |
STABLE_ONNX_OPSET_VERSION, | |
TracingAdapter, | |
dump_torchscript_IR, | |
scripting_with_instances, | |
) | |
from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model | |
from detectron2.modeling.postprocessing import detector_postprocess | |
from detectron2.projects.point_rend import add_pointrend_config | |
from detectron2.structures import Boxes | |
from detectron2.utils.env import TORCH_VERSION | |
from detectron2.utils.file_io import PathManager | |
from detectron2.utils.logger import setup_logger | |
def setup_cfg(args):
    """Assemble and freeze the config used for the export run.

    Starts from ``get_cfg()``, applies ``add_pointrend_config``, then merges
    the file ``args.config_file`` followed by the override list ``args.opts``.

    Args:
        args: parsed command-line namespace providing ``config_file`` and ``opts``.

    Returns:
        The frozen config object.
    """
    config = get_cfg()
    # cuda context is initialized before creating dataloader, so we don't fork anymore
    config.DATALOADER.NUM_WORKERS = 0
    add_pointrend_config(config)
    # The file is merged first so that command-line overrides win.
    config.merge_from_file(args.config_file)
    config.merge_from_list(args.opts)
    config.freeze()
    return config
def export_caffe2_tracing(cfg, torch_model, inputs):
    """Export ``torch_model`` through ``Caffe2Tracer`` in the requested format.

    Reads the module-level ``args`` for the target format (``args.format``)
    and the output directory (``args.output``).

    Args:
        cfg: config object handed to ``Caffe2Tracer``.
        torch_model: the model to export.
        inputs: sample inputs handed to ``Caffe2Tracer`` (and to ``save_graph``).

    Returns:
        The exported caffe2 model when ``args.format`` is "caffe2";
        ``None`` for every other format.
    """
    from detectron2.export import Caffe2Tracer

    tracer = Caffe2Tracer(cfg, torch_model, inputs)
    fmt = args.format
    out_dir = args.output

    if fmt == "onnx":
        import onnx

        onnx.save(tracer.export_onnx(), os.path.join(out_dir, "model.onnx"))
        return None

    if fmt == "torchscript":
        traced = tracer.export_torchscript()
        with PathManager.open(os.path.join(out_dir, "model.ts"), "wb") as stream:
            torch.jit.save(traced, stream)
        dump_torchscript_IR(traced, out_dir)
        return None

    if fmt != "caffe2":
        return None

    exported = tracer.export_caffe2()
    exported.save_protobuf(out_dir)
    # draw the caffe2 graph
    exported.save_graph(os.path.join(out_dir, "model.svg"), inputs=inputs)
    return exported
# experimental. API not yet final
def export_scripting(torch_model):
    """Script ``torch_model`` with TorchScript and save it under ``args.output``.

    Reads the module-level ``args``; only ``args.format == "torchscript"`` is
    accepted. Writes ``model.ts`` into the output directory and then calls
    ``dump_torchscript_IR`` on the scripted module.

    Args:
        torch_model: the model to script.

    Returns:
        Always ``None``: Python-side inference for the scripted model is not
        wired up (no postprocessing glue code).
    """
    assert TORCH_VERSION >= (1, 8)
    assert args.format == "torchscript", "Scripting only supports torchscript format."

    # Field name -> type mapping passed to scripting_with_instances.
    instance_fields = {
        "proposal_boxes": Boxes,
        "objectness_logits": Tensor,
        "pred_boxes": Boxes,
        "scores": Tensor,
        "pred_classes": Tensor,
        "pred_masks": Tensor,
        "pred_keypoints": torch.Tensor,
        "pred_keypoint_heatmaps": torch.Tensor,
    }

    class ScriptableAdapterBase(nn.Module):
        # Use this adapter to work around https://github.com/pytorch/pytorch/issues/46944
        # by not returning instances but dicts. Otherwise the exported model is not deployable
        def __init__(self):
            super().__init__()
            self.model = torch_model
            self.eval()

    # The forward() bodies below are compiled by TorchScript, so they are kept
    # in their original form; only the entry point into the model differs.
    if isinstance(torch_model, GeneralizedRCNN):

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model.inference(inputs, do_postprocess=False)
                return [i.get_fields() for i in instances]

    else:

        class ScriptableAdapter(ScriptableAdapterBase):
            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
                instances = self.model(inputs)
                return [i.get_fields() for i in instances]

    scripted = scripting_with_instances(ScriptableAdapter(), instance_fields)
    target = os.path.join(args.output, "model.ts")
    with PathManager.open(target, "wb") as stream:
        torch.jit.save(scripted, stream)
    dump_torchscript_IR(scripted, args.output)
    # TODO inference in Python now missing postprocessing glue code
    return None
# experimental. API not yet final
def export_tracing(torch_model, inputs):
    """Export ``torch_model`` by tracing, to TorchScript or ONNX.

    Reads the module-level ``args`` (``args.format``, ``args.output``) and
    ``logger``. Only the "image" entry of the first sample is used for tracing.

    Args:
        torch_model: the model to trace.
        inputs: list of sample input dicts; ``inputs[0]["image"]`` is traced.

    Returns:
        A callable usable for evaluation when the format is "torchscript" and
        the model is a ``GeneralizedRCNN`` or ``RetinaNet``; otherwise ``None``.

    Raises:
        ValueError: if ``args.format`` is neither "torchscript" nor "onnx".
            Without this check such a format would write no model file at all
            while the script still reported success.
    """
    assert TORCH_VERSION >= (1, 8)
    if args.format not in ("torchscript", "onnx"):
        raise ValueError(
            "Tracing only supports torchscript and onnx formats, "
            f"got format={args.format}."
        )

    image = inputs[0]["image"]
    inputs = [{"image": image}]  # remove other unused keys

    if isinstance(torch_model, GeneralizedRCNN):

        def inference(model, inputs):
            # use do_postprocess=False so it returns ROI mask
            inst = model.inference(inputs, do_postprocess=False)[0]
            return [{"instances": inst}]

    else:
        inference = None  # assume that we just call the model directly

    traceable_model = TracingAdapter(torch_model, inputs, inference)

    if args.format == "torchscript":
        ts_model = torch.jit.trace(traceable_model, (image,))
        with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
            torch.jit.save(ts_model, f)
        dump_torchscript_IR(ts_model, args.output)
    elif args.format == "onnx":
        with PathManager.open(os.path.join(args.output, "model.onnx"), "wb") as f:
            torch.onnx.export(traceable_model, (image,), f, opset_version=STABLE_ONNX_OPSET_VERSION)
    logger.info("Inputs schema: " + str(traceable_model.inputs_schema))
    logger.info("Outputs schema: " + str(traceable_model.outputs_schema))

    # Only the TorchScript export of these model types gets an evaluation wrapper.
    if args.format != "torchscript":
        return None
    if not isinstance(torch_model, (GeneralizedRCNN, RetinaNet)):
        return None

    def eval_wrapper(inputs):
        """
        The exported model does not contain the final resize step, which is typically
        unused in deployment but needed for evaluation. We add it manually here.
        """
        sample = inputs[0]  # only the first element of the batch is evaluated
        instances = traceable_model.outputs_schema(ts_model(sample["image"]))[0]["instances"]
        postprocessed = detector_postprocess(instances, sample["height"], sample["width"])
        return [{"instances": postprocessed}]

    return eval_wrapper
def get_sample_inputs(args):
    """Produce one batch of sample inputs for the exporters.

    Reads the module-level ``cfg``.

    Args:
        args: parsed command-line namespace; ``args.sample_image`` is either
            ``None`` or the path of an image file.

    Returns:
        The first batch of the test loader for ``cfg.DATASETS.TEST[0]`` when no
        sample image is given; otherwise a one-element list holding a dict with
        the resized image tensor under "image" and the original image's
        "height" and "width".
    """
    if args.sample_image is not None:
        # get a sample data
        raw = detection_utils.read_image(args.sample_image, format=cfg.INPUT.FORMAT)
        # Do same preprocessing as DefaultPredictor
        resize = T.ResizeShortestEdge(
            [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
        )
        # Height and width are taken from the image before resizing.
        height, width = raw.shape[:2]
        resized = resize.get_transform(raw).apply_image(raw)
        # Move the last axis to the front and convert to a float32 tensor.
        tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
        # Sample ready
        return [{"image": tensor, "height": height, "width": width}]

    # get a first batch from dataset
    loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    return next(iter(loader))
if __name__ == "__main__":
    # Script entry point: parse the command line, build and load the model,
    # export it with the selected method, and optionally evaluate the result.
    # `args`, `logger` and `cfg` are deliberately module-level names: the
    # export helpers defined above read them as globals.
    parser = argparse.ArgumentParser(description="Export a model for deployment.")
    parser.add_argument(
        "--format",
        choices=["caffe2", "onnx", "torchscript"],
        help="output format",
        default="torchscript",
    )
    parser.add_argument(
        "--export-method",
        choices=["caffe2_tracing", "tracing", "scripting"],
        help="Method to export models",
        default="tracing",
    )
    parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
    parser.add_argument("--sample-image", default=None, type=str, help="sample image for input")
    parser.add_argument("--run-eval", action="store_true")
    # The output directory is used unconditionally below (mkdirs, os.path.join),
    # so require it here instead of failing later on a None path.
    parser.add_argument(
        "--output", required=True, help="output directory for the converted model"
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    args = parser.parse_args()
    logger = setup_logger()
    logger.info("Command line arguments: " + str(args))
    PathManager.mkdirs(args.output)
    # Disable re-specialization on new shapes. Otherwise --run-eval will be slow
    torch._C._jit_set_bailout_depth(1)

    cfg = setup_cfg(args)

    # create a torch model
    torch_model = build_model(cfg)
    DetectionCheckpointer(torch_model).resume_or_load(cfg.MODEL.WEIGHTS)
    torch_model.eval()

    # convert and save model
    # (argparse `choices` guarantees that exactly one of these branches runs)
    if args.export_method == "caffe2_tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_caffe2_tracing(cfg, torch_model, sample_inputs)
    elif args.export_method == "scripting":
        exported_model = export_scripting(torch_model)
    elif args.export_method == "tracing":
        sample_inputs = get_sample_inputs(args)
        exported_model = export_tracing(torch_model, sample_inputs)

    # run evaluation with the converted model
    if args.run_eval:
        assert exported_model is not None, (
            "Python inference is not yet implemented for "
            f"export_method={args.export_method}, format={args.format}."
        )
        logger.info("Running evaluation ... this takes a long time if you export to CPU.")
        dataset = cfg.DATASETS.TEST[0]
        data_loader = build_detection_test_loader(cfg, dataset)
        # NOTE: hard-coded evaluator. change to the evaluator for your dataset
        evaluator = COCOEvaluator(dataset, output_dir=args.output)
        metrics = inference_on_dataset(exported_model, data_loader, evaluator)
        print_csv_format(metrics)
    logger.info("Success.")