From 60205f80f300a7000373ea2dd6ff73574bf16978 Mon Sep 17 00:00:00 2001
From: Yuxin Wu
Date: Mon, 4 May 2020 10:46:59 -0700
Subject: [PATCH] allow converter to dump torchscript; support torchscript GPU inference

Reviewed By: rbgirshick

Differential Revision: D21364163

fbshipit-source-id: 6d83968b483f91df976939d8682031a5c60dd271
---
 detectron2/export/api.py                      | 15 ++++---
 detectron2/export/c10.py                      |  3 --
 tools/deploy/CMakeLists.txt                   |  7 +++
 tools/deploy/caffe2_converter.py              | 44 ++++++++++++++++---
 tools/deploy/torchscript_traced_mask_rcnn.cpp | 25 +++++------
 5 files changed, 66 insertions(+), 28 deletions(-)

diff --git a/detectron2/export/api.py b/detectron2/export/api.py
index f71063a4b5..a7600714e1 100644
--- a/detectron2/export/api.py
+++ b/detectron2/export/api.py
@@ -15,7 +15,13 @@
 from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format
 from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph
 
-__all__ = ["add_export_config", "export_caffe2_model", "Caffe2Model", "export_onnx_model"]
+__all__ = [
+    "add_export_config",
+    "export_caffe2_model",
+    "Caffe2Model",
+    "export_onnx_model",
+    "Caffe2Tracer",
+]
 
 
 def add_export_config(cfg):
@@ -47,7 +53,8 @@ class Caffe2Tracer:
     3. complicated pre/post processing
 
     This class provides a traceable version of a detectron2 model by:
-    1. Rewrite parts of the model using ops in caffe2
+    1. Rewrite parts of the model using ops in caffe2. Note that some ops do
+       not have GPU implementation.
     2. Define the inputs "after pre-processing" as inputs to the model
     3. Remove post-processing and produce raw layer outputs
 
@@ -59,8 +66,6 @@ class Caffe2Tracer:
     model to different deployment formats.
 
     The class currently only supports models using builtin meta architectures.
-
-    Experimental. Don't use.
     """
 
     def __init__(self, cfg, model, inputs):
@@ -127,7 +132,7 @@ def export_torchscript(self):
         logger = logging.getLogger(__name__)
         logger.info("Tracing the model with torch.jit.trace ...")
         with torch.no_grad():
-            return torch.jit.trace(model, (inputs,))
+            return torch.jit.trace(model, (inputs,), optimize=True)
 
 
 def export_caffe2_model(cfg, model, inputs):
diff --git a/detectron2/export/c10.py b/detectron2/export/c10.py
index 66085b01c9..6e3cbe3ce9 100644
--- a/detectron2/export/c10.py
+++ b/detectron2/export/c10.py
@@ -164,9 +164,6 @@ def forward(self, images, features, gt_instances=None):
         features = [features[f] for f in self.in_features]
         objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features)
 
-        # TODO is the needed?
-        # objectness_logits_pred = [t.sigmoid() for t in objectness_logits_pred]
-
         assert isinstance(images, ImageList)
         if self.tensor_mode:
             im_info = images.image_sizes
diff --git a/tools/deploy/CMakeLists.txt b/tools/deploy/CMakeLists.txt
index 4ac50f005b..0c3ca7a33c 100644
--- a/tools/deploy/CMakeLists.txt
+++ b/tools/deploy/CMakeLists.txt
@@ -12,3 +12,10 @@ target_link_libraries(
   caffe2_mask_rcnn
   "${TORCH_LIBRARIES}" gflags glog ${OpenCV_LIBS})
 set_property(TARGET caffe2_mask_rcnn PROPERTY CXX_STANDARD 14)
+
+
+add_executable(torchscript_traced_mask_rcnn torchscript_traced_mask_rcnn.cpp)
+target_link_libraries(
+  torchscript_traced_mask_rcnn
+  "${TORCH_LIBRARIES}" ${OpenCV_LIBS})
+set_property(TARGET torchscript_traced_mask_rcnn PROPERTY CXX_STANDARD 14)
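
The Caffe2Tracer API exported above can also be driven directly from Python.
Below is a minimal sketch of that flow, mirroring the converter script changed
in the next file; the config path is a placeholder, and a real batch of data is
required because tracing actually runs the model:

    import torch

    from detectron2.checkpoint import DetectionCheckpointer
    from detectron2.config import get_cfg
    from detectron2.data import build_detection_test_loader
    from detectron2.export import Caffe2Tracer, add_export_config
    from detectron2.modeling import build_model

    cfg = add_export_config(get_cfg())
    cfg.merge_from_file("config.yaml")  # placeholder; use your model's config
    cfg.freeze()

    model = build_model(cfg)
    DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS)
    model.eval()

    # Tracing executes the model once, so it needs one real batch of inputs.
    data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    first_batch = next(iter(data_loader))

    tracer = Caffe2Tracer(cfg, model, first_batch)
    ts_model = tracer.export_torchscript()  # torch.jit.trace(..., optimize=True)
    ts_model.save("model.ts")

export_caffe2() and export_onnx() can be substituted for export_torchscript()
to reach the other two formats supported by the converter.
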
diff --git a/tools/deploy/caffe2_converter.py b/tools/deploy/caffe2_converter.py
index d86c6382a4..08feb69fba 100755
--- a/tools/deploy/caffe2_converter.py
+++ b/tools/deploy/caffe2_converter.py
@@ -2,13 +2,14 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import argparse
 import os
+import onnx
 import torch
 
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
 from detectron2.data import build_detection_test_loader
 from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format
-from detectron2.export import add_export_config, export_caffe2_model
+from detectron2.export import Caffe2Tracer, add_export_config
 from detectron2.modeling import build_model
 from detectron2.utils.logger import setup_logger
 
@@ -28,10 +29,16 @@ def setup_cfg(args):
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert a model to Caffe2")
+    parser = argparse.ArgumentParser(description="Convert a model using caffe2 tracing.")
+    parser.add_argument(
+        "--format",
+        choices=["caffe2", "onnx", "torchscript"],
+        help="output format",
+        default="caffe2",
+    )
     parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file")
     parser.add_argument("--run-eval", action="store_true")
-    parser.add_argument("--output", help="output directory for the converted caffe2 model")
+    parser.add_argument("--output", help="output directory for the converted model")
     parser.add_argument(
         "opts",
         help="Modify config options using the command-line",
@@ -41,6 +48,7 @@ def setup_cfg(args):
     args = parser.parse_args()
     logger = setup_logger()
     logger.info("Command line arguments: " + str(args))
+    os.makedirs(args.output, exist_ok=True)
 
     cfg = setup_cfg(args)
 
@@ -53,13 +61,35 @@ def setup_cfg(args):
     first_batch = next(iter(data_loader))
 
     # convert and save caffe2 model
-    caffe2_model = export_caffe2_model(cfg, torch_model, first_batch)
-    caffe2_model.save_protobuf(args.output)
-    # draw the caffe2 graph
-    caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
+    tracer = Caffe2Tracer(cfg, torch_model, first_batch)
+    if args.format == "caffe2":
+        caffe2_model = tracer.export_caffe2()
+        caffe2_model.save_protobuf(args.output)
+        # draw the caffe2 graph
+        caffe2_model.save_graph(os.path.join(args.output, "model.svg"), inputs=first_batch)
+    elif args.format == "onnx":
+        onnx_model = tracer.export_onnx()
+        onnx.save(onnx_model, os.path.join(args.output, "model.onnx"))
+    elif args.format == "torchscript":
+        script_model = tracer.export_torchscript()
+        script_model.save(os.path.join(args.output, "model.ts"))
+
+        # Recursively print IR of all modules
+        with open(os.path.join(args.output, "model_ts_IR.txt"), "w") as f:
+            try:
+                f.write(script_model._actual_script_module._c.dump_to_str(True, False, False))
+            except AttributeError:
+                pass
+        # Print IR of the entire graph (all submodules inlined)
+        with open(os.path.join(args.output, "model_ts_IR_inlined.txt"), "w") as f:
+            f.write(str(script_model.inlined_graph))
+        # Print the model structure in pytorch style
+        with open(os.path.join(args.output, "model.txt"), "w") as f:
+            f.write(str(script_model))
 
     # run evaluation with the converted model
     if args.run_eval:
+        assert args.format == "caffe2", "Python inference in other format is not yet supported."
         dataset = cfg.DATASETS.TEST[0]
         data_loader = build_detection_test_loader(cfg, dataset)
         # NOTE: hard-coded evaluator. change to the evaluator for your dataset
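
In the torchscript branch above, model.ts is the serialized module itself,
while the three .txt files are human-readable dumps of its IR and structure.
A short sketch of re-loading the saved artifact and reproducing the
inlined-graph dump through the same torch.jit API (the path is a placeholder):

    import torch

    ts_model = torch.jit.load("output/model.ts")  # placeholder path
    print(ts_model.graph)          # top-level IR; submodule calls not inlined
    print(ts_model.inlined_graph)  # what model_ts_IR_inlined.txt contains
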
diff --git a/tools/deploy/torchscript_traced_mask_rcnn.cpp b/tools/deploy/torchscript_traced_mask_rcnn.cpp
index 1d8340b3d1..82fbdb052f 100644
--- a/tools/deploy/torchscript_traced_mask_rcnn.cpp
+++ b/tools/deploy/torchscript_traced_mask_rcnn.cpp
@@ -9,7 +9,7 @@
 using namespace std;
 
-// Experimental. Don't use.
+// experimental. don't use
 int main(int argc, const char* argv[]) {
   if (argc != 3) {
     return 1;
   }
@@ -19,26 +19,25 @@ int main(int argc, const char* argv[]) {
 
   torch::autograd::AutoGradMode guard(false);
   auto module = torch::jit::load(argv[1]);
 
+  assert(module.buffers().size() > 0);
+  // Assume that the entire model is on the same device.
+  // We just put input to this device.
+  auto device = (*begin(module.buffers())).device();
+
   cv::Mat input_img = cv::imread(image_file, cv::IMREAD_COLOR);
   const int height = input_img.rows;
   const int width = input_img.cols;
   // FPN models require divisibility of 32
   assert(height % 32 == 0 && width % 32 == 0);
-  const int batch = 1;
   const int channels = 3;
-  auto input = torch::empty({1, channels, height, width});
-  float* ptr = input.data_ptr<float>();
-  // HWC to CHW
-  for (int c = 0; c < 3; ++c) {
-    for (int i = 0; i < height * width; ++i) {
-      ptr[c * height * width + i] =
-          static_cast<float>(input_img.data[3 * i + c]);
-    }
-  }
+  auto input = torch::from_blob(
+      input_img.data, {1, height, width, channels}, torch::kUInt8);
+  // NHWC to NCHW
+  input = input.to(device, torch::kFloat).permute({0, 3, 1, 2}).contiguous();
 
-  float im_info_data[] = {height * 1.0f, width * 1.0f, 1.0f};
-  auto im_info = torch::from_blob(im_info_data, {1, 3});
+  std::array<float, 3> im_info_data{height * 1.0f, width * 1.0f, 1.0f};
+  auto im_info = torch::from_blob(im_info_data.data(), {1, 3}).to(device);
 
   // run the network
   auto output = module.forward({std::make_tuple(input, im_info)});
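
For quick sanity checks without building the C++ tool, a rough Python analogue
of the runner above may help. It is a sketch under the same assumptions as the
C++ code: the traced model takes an (image, im_info) tuple, the whole model
sits on one device, and the image size is already divisible by 32; file paths
are placeholders:

    import cv2
    import torch

    module = torch.jit.load("output/model.ts")  # placeholder path
    # Assume the entire model sits on one device and send inputs there,
    # mirroring the buffer-device lookup in the C++ code above.
    device = next(module.buffers()).device

    img = cv2.imread("input.jpg", cv2.IMREAD_COLOR)  # HWC, BGR, uint8
    height, width = img.shape[:2]
    assert height % 32 == 0 and width % 32 == 0  # FPN divisibility requirement

    # NHWC uint8 -> NCHW float, on the model's device
    inputs = torch.from_numpy(img)[None].to(device, torch.float)
    inputs = inputs.permute(0, 3, 1, 2).contiguous()
    im_info = torch.tensor([[height, width, 1.0]], device=device)

    with torch.no_grad():
        outputs = module((inputs, im_info))

Reading the device from the first buffer, as the C++ code does, avoids
hard-coding CPU or GPU and lets the same script serve a model exported for
either device.
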