Release v0.7.1
ptarasiewiczNV committed Aug 21, 2023
1 parent 52bf655 commit 3ee5574
Showing 7 changed files with 91 additions and 44 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,23 @@ limitations under the License.

# Changelog

## 0.7.1
- fix: gather onnx input names based on model's forward signature
- fix: do not run TensorRT max batch size search when max batch size is None
- fix: use pytree metadata to flatten torch complex outputs

- Versions of external components used during testing:
- [PyTorch 2.1.0a0+b5021ba](https://github.com/pytorch/pytorch/commit/b5021ba9)
- [TensorFlow 2.12.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.12.0)
- [TensorRT 8.6.1](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html)
- [ONNX Runtime 1.15.1](https://github.com/microsoft/onnxruntime/tree/v1.15.1)
- [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/): 0.47.1
- [GraphSurgeon](https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon/): 0.3.27
- [tf2onnx v1.14.0](https://github.com/onnx/tensorflow-onnx/releases/tag/v1.14.0)
- Other component versions depend on the versions of the framework containers used.
  See the [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
  for a detailed summary.

## 0.7.0
- new: Inplace Optimize feature - optimize models directly in the Python code
- new: Non-tensor inputs and outputs support
2 changes: 1 addition & 1 deletion model_navigator/__version__.py
@@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# noqa: D100
__version__ = "0.7.0"
__version__ = "0.7.1"
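
A quick sanity check after upgrading is to read the module bumped above; this assumes only that `model_navigator` is installed in the current environment.

```python
# Confirm the installed release by importing the version module changed in this commit.
from model_navigator.__version__ import __version__

print(__version__)  # expected: "0.7.1"
```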
24 changes: 19 additions & 5 deletions model_navigator/commands/convert/base.py
@@ -60,6 +60,9 @@ def _execute_conversion(
)

if run_search:
assert dataloader_max_batch_size is not None
assert device_max_batch_size is not None

conversion_max_batch_size = (
cls._execute_conversion_with_max_batch_size_search( # TODO what is the best value?
convert_func=convert_func,
@@ -68,14 +71,25 @@ dataloader_max_batch_size=dataloader_max_batch_size,
dataloader_max_batch_size=dataloader_max_batch_size,
)
)
LOGGER.info(f"Converted with maximal batch size: {conversion_max_batch_size}.")
else:
LOGGER.info("Search for maximal batch size disable. Execute single conversion.")
conversion_max_batch_size = dataloader_max_batch_size
convert_func(get_args())
conversion_max_batch_size = cls._execute_single_conversion(
convert_func=convert_func, get_args=get_args, max_batch_size=dataloader_max_batch_size
)
LOGGER.info(f"Converted with maximal batch size: {conversion_max_batch_size}.")

return conversion_max_batch_size

@classmethod
def _execute_single_conversion(
cls,
convert_func: Callable,
get_args: Callable,
max_batch_size: int,
):
LOGGER.info("Search for maximal batch size disable. Execute single conversion.")
convert_func(get_args())
return max_batch_size

@classmethod
def _execute_conversion_with_max_batch_size_search(
cls,
@@ -213,7 +227,7 @@ def _run_search(
LOGGER.info("`batch_dim` is None. Model does not support batching.")
return False

if not cls._is_valid_batch_size(dataloader_batch_size) and not cls._is_valid_batch_size(device_max_batch_size):
if not cls._is_valid_batch_size(dataloader_batch_size) or not cls._is_valid_batch_size(device_max_batch_size):
LOGGER.info(
"Dataloader or device max batch size is invalid.\n"
"Provided values:\n"
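
Taken together, the hunks above make two behavioral guarantees: the max-batch-size search only runs when both batch-size hints are usable (hence the `and` that became `or` in the early-exit check), and the no-search path is funneled through the new `_execute_single_conversion` helper. A self-contained sketch of that decision, using hypothetical helper names rather than the library's API, could look like this:

```python
# Simplified, illustrative decision flow; names are hypothetical, not Model Navigator's API.
from typing import Optional


def _is_valid_batch_size(value: Optional[int]) -> bool:
    return value is not None and value > 0


def should_run_max_batch_size_search(
    batch_dim: Optional[int],
    dataloader_max_batch_size: Optional[int],
    device_max_batch_size: Optional[int],
) -> bool:
    if batch_dim is None:
        return False  # the model does not support batching at all
    # After the 0.7.1 fix a single invalid value is enough to skip the search,
    # which is why the check uses `or` instead of the previous `and`.
    if not _is_valid_batch_size(dataloader_max_batch_size) or not _is_valid_batch_size(device_max_batch_size):
        return False
    return True


assert should_run_max_batch_size_search(0, 8, None) is False  # falls back to a single conversion
assert should_run_max_batch_size_search(0, 8, 16) is True     # both hints usable: search runs
```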
18 changes: 17 additions & 1 deletion model_navigator/commands/export/exporters/torch2onnx.py
@@ -13,6 +13,7 @@
# limitations under the License.
"""Export Torch model to ONNX model."""

import inspect
import pathlib
from typing import Any, Dict, List, Optional

@@ -71,7 +72,22 @@ def export(

dummy_input = {n: torch.from_numpy(val).to(target_device) for n, val in profiling_sample.items()}
dummy_input = input_metadata.unflatten_sample(dummy_input)
input_names = list(input_metadata.keys())

forward_argspec = inspect.getfullargspec(model.forward)
forward_args = forward_argspec.args[1:]

args_mapping, kwargs_mapping = input_metadata.pytree_metadata.get_names_mapping()

for argname in kwargs_mapping:
assert argname in forward_args, f"Argument {argname} is not in forward argspec."

input_names = []
for args_names in args_mapping:
input_names.extend(args_names)

for argname in forward_args:
if argname in kwargs_mapping:
input_names.extend(kwargs_mapping[argname])

exported_model_path = pathlib.Path(exported_model_path)
if not exported_model_path.is_absolute():
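
The export change above replaces the old `input_names = list(input_metadata.keys())` with names ordered by the model's `forward` signature. A toy illustration of the idea follows; the module and the `kwargs_mapping` dict are made up for the example, and only the `inspect`-based ordering mirrors the commit:

```python
# Order ONNX input names by the model's forward() signature (toy example).
import inspect

import torch


class ToyModel(torch.nn.Module):
    def forward(self, input_ids, attention_mask=None):
        out = input_ids.float()
        if attention_mask is not None:
            out = out * attention_mask
        return out


model = ToyModel()

# Argument names as declared on forward(), skipping `self`.
forward_args = inspect.getfullargspec(model.forward).args[1:]
print(forward_args)  # ['input_ids', 'attention_mask']

# Hypothetical mapping from forward argument name to flattened tensor names;
# iterating over forward_args keeps the ONNX input names in signature order.
kwargs_mapping = {"attention_mask": ["attention_mask"], "input_ids": ["input_ids"]}
input_names = [name for arg in forward_args if arg in kwargs_mapping for name in kwargs_mapping[arg]]
print(input_names)  # ['input_ids', 'attention_mask']
```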
26 changes: 25 additions & 1 deletion model_navigator/core/tensor.py
@@ -303,7 +303,7 @@ def unflatten_sample(self, sample: Dict[str, Any], wrap_input: bool = False) ->
If wrap_input is True, then single tensor will be wrapped in tuple.
"""
unflatten_sample = self._unflatten_sample(sample, self._metadata)
if wrap_input and isinstance(self._metadata, (str, dict)):
if wrap_input and isinstance(self._metadata, (str, Mapping)):
unflatten_sample = (unflatten_sample,)
return unflatten_sample

@@ -318,6 +318,30 @@ def is_compatible_with(self, sample: Any) -> bool:
"""
return self._is_compatible_with(self._metadata, sample)

def get_names_mapping(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
"""Get mapping of PyTree metadata to names."""
metadata = self._metadata
if isinstance(metadata, (str, Mapping)):
metadata = (metadata,)

if isinstance(metadata[-1], Mapping):
args, kwargs = metadata[:-1], metadata[-1]
else:
args, kwargs = metadata, {}

args_mapping, kwargs_mapping = [], {}
for arg in args:
flattened = {}
self._flatten_sample(arg, arg, flattened, include_constants=True)
args_mapping.append(list(flattened.keys()))

for key, arg in kwargs.items():
flattened = {}
self._flatten_sample(arg, arg, flattened, include_constants=True)
kwargs_mapping[key] = list(flattened.keys())

return args_mapping, kwargs_mapping

def _is_compatible_with(self, metadata, sample):
if isinstance(metadata, str):
return is_tensor(sample, self.tensor_type)
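
`get_names_mapping` interprets the stored PyTree metadata as `(args..., kwargs)`, where a trailing mapping plays the role of keyword arguments, and flattens each branch into an ordered list of tensor names. Below is a standalone sketch of that split, with a deliberately simplified flattener that only handles tuples, lists, dicts, and string leaves; the real method delegates to the class's `_flatten_sample`:

```python
# Simplified stand-in for get_names_mapping; the real code delegates to _flatten_sample.
from typing import Any, Dict, List, Mapping, Sequence, Tuple


def _flatten_names(node: Any) -> List[str]:
    if isinstance(node, str):
        return [node]
    if isinstance(node, Mapping):
        return [name for value in node.values() for name in _flatten_names(value)]
    if isinstance(node, (list, tuple)):
        return [name for item in node for name in _flatten_names(item)]
    raise TypeError(f"Unsupported metadata node: {node!r}")


def get_names_mapping(metadata: Any) -> Tuple[Sequence[List[str]], Dict[str, List[str]]]:
    if isinstance(metadata, (str, Mapping)):
        metadata = (metadata,)
    # A trailing mapping is treated as **kwargs; everything before it is positional.
    if isinstance(metadata[-1], Mapping):
        args, kwargs = metadata[:-1], metadata[-1]
    else:
        args, kwargs = metadata, {}
    args_mapping = [_flatten_names(arg) for arg in args]
    kwargs_mapping = {key: _flatten_names(value) for key, value in kwargs.items()}
    return args_mapping, kwargs_mapping


print(get_names_mapping(("input__0", {"mask": ("input__1", "input__2")})))
# ([['input__0']], {'mask': ['input__1', 'input__2']})
```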
18 changes: 2 additions & 16 deletions model_navigator/runners/torch.py
@@ -12,8 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Torch runners."""
from collections import OrderedDict
from typing import List, Mapping
from typing import List

from model_navigator.api.config import Format, TensorType
from model_navigator.core.tensor import get_tensor_type
@@ -23,7 +22,6 @@
from model_navigator.runners.registry import register_runner
from model_navigator.utils import module
from model_navigator.utils.common import numpy_to_torch_dtype
from model_navigator.utils.dataloader import get_default_output_names

torch = module.lazy_import("torch")

@@ -63,19 +61,7 @@ def infer_impl(self, feed_dict, return_raw_outputs=False):
if return_raw_outputs:
return outputs

if torch.is_tensor(outputs):
outputs = (outputs,)
if isinstance(outputs, Mapping):
outputs = outputs.values()

out_dict = OrderedDict()
if self.output_metadata:
output_names = self.output_metadata.keys()
else:
output_names = outputs.keys() if isinstance(outputs, Mapping) else get_default_output_names(len(outputs))

for name, output in zip(output_names, outputs):
out_dict[name] = output
out_dict = self.output_metadata.flatten_sample(outputs)
out_dict = self._prepare_outputs(out_dict)

return out_dict
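
The runner no longer special-cases tensor and `Mapping` outputs by hand; it hands the raw outputs to `output_metadata.flatten_sample`, which walks the recorded structure and yields a flat name-to-tensor dict. A rough, self-contained sketch of what such a flattening does, with illustrative dotted names rather than the names actually stored in `output_metadata`:

```python
# Illustrative flattening of nested model outputs into an ordered name -> value dict.
from collections import OrderedDict
from typing import Any, Dict


def flatten_outputs(outputs: Any, prefix: str = "output") -> Dict[str, Any]:
    flat: "OrderedDict[str, Any]" = OrderedDict()

    def _walk(node: Any, name: str) -> None:
        if isinstance(node, dict):
            for key, value in node.items():
                _walk(value, f"{name}.{key}")
        elif isinstance(node, (list, tuple)):
            for index, value in enumerate(node):
                _walk(value, f"{name}.{index}")
        else:
            flat[name] = node  # leaf: a tensor or other array-like value

    _walk(outputs, prefix)
    return flat


# A nested output such as (logits, {"aux": (loss_a, loss_b)}) becomes a flat dict:
print(list(flatten_outputs(("logits", {"aux": ("a", "b")})).keys()))
# ['output.0', 'output.1.aux.0', 'output.1.aux.1']
```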
30 changes: 10 additions & 20 deletions tests/unit/base/test_commands_convert_onnx_onnx2trt.py
@@ -149,7 +149,7 @@ def test_run_execute_conversion_when_dataloader_and_device_max_batch_size_is_inv
assert ConvertONNX2TRT._execute_conversion.called is True # pytype: disable=attribute-error


def test_run_execute_conversion_with_max_batch_size_search_when_dataloader_max_batch_size_provided(mocker):
def test_run_execute_single_conversion_when_only_dataloader_max_batch_size_provided(mocker):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
workspace = tmpdir / "navigator_workspace"
@@ -161,11 +161,9 @@ def test_run_execute_conversion_with_max_batch_size_search_when_dataloader_max_b
output_model_path = workspace / "trt-fp16" / "model.plan"
output_model_path.parent.mkdir(parents=True)

with mocker.patch.object(
ConvertONNX2TRT, "_execute_conversion_with_max_batch_size_search", return_value=3
), mocker.patch.object(ConvertONNX2TRT, "_get_onnx_input_metadata"), mocker.patch(
"model_navigator.utils.devices.get_available_gpus", return_value=[0]
):
with mocker.patch.object(ConvertONNX2TRT, "_execute_single_conversion", return_value=3), mocker.patch.object(
ConvertONNX2TRT, "_get_onnx_input_metadata"
), mocker.patch("model_navigator.utils.devices.get_available_gpus", return_value=[0]):
result = ConvertONNX2TRT().run(
workspace=Workspace(workspace),
parent_path=input_model_path,
@@ -186,13 +184,10 @@

assert result is not None
assert result.status == CommandStatus.OK
assert (
ConvertONNX2TRT._execute_conversion_with_max_batch_size_search.called
is True # pytype: disable=attribute-error
)
assert ConvertONNX2TRT._execute_single_conversion.called is True # pytype: disable=attribute-error


def test_run_execute_conversion_with_max_batch_size_search_when_device_max_batch_size_provided(mocker):
def test_run_execute_single_conversion_when_only_device_max_batch_size_provided(mocker):
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = pathlib.Path(tmpdir)
workspace = tmpdir / "navigator_workspace"
@@ -204,11 +199,9 @@ def test_run_execute_conversion_with_max_batch_size_search_when_device_max_batch
output_model_path = workspace / "trt-fp16" / "model.plan"
output_model_path.parent.mkdir(parents=True)

with mocker.patch.object(
ConvertONNX2TRT, "_execute_conversion_with_max_batch_size_search", return_value=3
), mocker.patch.object(ConvertONNX2TRT, "_get_onnx_input_metadata"), mocker.patch(
"model_navigator.utils.devices.get_available_gpus", return_value=[0]
):
with mocker.patch.object(ConvertONNX2TRT, "_execute_single_conversion", return_value=3), mocker.patch.object(
ConvertONNX2TRT, "_get_onnx_input_metadata"
), mocker.patch("model_navigator.utils.devices.get_available_gpus", return_value=[0]):
result = ConvertONNX2TRT().run(
workspace=Workspace(workspace),
parent_path=input_model_path,
@@ -229,10 +222,7 @@ def test_run_execute_conversion_with_max_batch_size_search_when_both_max_batch

assert result is not None
assert result.status == CommandStatus.OK
assert (
ConvertONNX2TRT._execute_conversion_with_max_batch_size_search.called
is True # pytype: disable=attribute-error
)
assert ConvertONNX2TRT._execute_single_conversion.called is True # pytype: disable=attribute-error


def test_run_execute_conversion_with_max_batch_size_search_when_both_max_batch_size_provided(mocker):
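
The renamed tests keep the same patching pattern, only swapping `_execute_conversion_with_max_batch_size_search` for `_execute_single_conversion` as the patched classmethod. A stripped-down example of that pattern, with a made-up `Converter` class standing in for `ConvertONNX2TRT` (requires pytest and pytest-mock):

```python
# Minimal pytest-mock example of the patch-and-assert-called pattern used above.


class Converter:
    @classmethod
    def _execute_single_conversion(cls, convert_func, get_args, max_batch_size):
        convert_func(get_args())
        return max_batch_size

    def run(self, max_batch_size):
        return self._execute_single_conversion(
            convert_func=lambda args: None, get_args=dict, max_batch_size=max_batch_size
        )


def test_run_calls_single_conversion(mocker):
    mocker.patch.object(Converter, "_execute_single_conversion", return_value=3)

    result = Converter().run(max_batch_size=16)

    assert result == 3
    assert Converter._execute_single_conversion.called is True
```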
