Skip to content

Commit

Permalink
fix handling dynamic axes from custon onnx config
Browse files Browse the repository at this point in the history
  • Loading branch information
ptarasiewiczNV committed Feb 22, 2023
1 parent 0e94c1a commit d701b15
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 63 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,23 @@ limitations under the License.

# Changelog

## 0.4.1
- fix: when specified use dynamic axes from custom OnnxConfig

[//]: <> (put here on external component update with short summary what change or link to changelog)

- Version of external components used during testing:
- [PyTorch 1.14.0a0+410ce96](https://github.com/pytorch/pytorch/commit/410ce96)
- [TensorFlow 2.11.0](https://github.com/tensorflow/tensorflow/releases/tag/v2.11.0)
- [TensorRT 8.5.2.2](https://docs.nvidia.com/deeplearning/tensorrt/release-notes/index.html)
- [ONNX Runtime 1.13.1](https://github.com/microsoft/onnxruntime/tree/v1.13.1)
- [Polygraphy](https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/): 0.43.1
- [GraphSurgeon](https://github.com/NVIDIA/TensorRT/tree/master/tools/onnx-graphsurgeon/): 0.4.6
- [tf2onnx v1.13.0](https://github.com/onnx/tensorflow-onnx/releases/tag/v1.13.0)
- Other component versions depend on the used framework and Triton Inference Server containers versions.
See its [support matrix](https://docs.nvidia.com/deeplearning/frameworks/support-matrix/index.html)
for a detailed summary.

## 0.4.0
- new: `optimize` method that replace `export` and perform max batch size search and improved profiling during process
- new: Introduced custom configs in `optimize` for better parametrization of export/conversion commands
Expand Down
2 changes: 1 addition & 1 deletion model_navigator/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# noqa: D100
__version__ = "0.4.0"
__version__ = "0.4.1"
30 changes: 28 additions & 2 deletions model_navigator/commands/export/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
The module provide functionality to export model to TorchScript and/or ONNX.
"""
from pathlib import Path
from typing import Any, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union

from model_navigator.api.config import JitType
from model_navigator.commands.base import Command, CommandOutput, CommandStatus
from model_navigator.commands.export import exporters
from model_navigator.exceptions import ModelNavigatorConfigurationError
from model_navigator.execution_context import ExecutionContext
from model_navigator.logger import LOGGER
from model_navigator.utils.common import parse_kwargs_to_cmd
Expand Down Expand Up @@ -140,6 +141,7 @@ def _run(
forward_kw_names: Optional[Tuple[str, ...]] = None,
model: Optional[Any] = None,
batch_dim: Optional[int] = None,
dynamic_axes: Optional[Dict[str, Union[Dict[int, str], List[int]]]] = None,
) -> CommandOutput:
"""Execute command.
Expand All @@ -154,6 +156,7 @@ def _run(
forward_kw_names: Additional arguments to override input names
model: The model that has to be exported
batch_dim: Location of batch position in shapes
dynamic_axes: Definition of model inputs dynamic axes
Returns:
CommandOutput object with status
Expand All @@ -169,6 +172,12 @@ def _run(
if model is None:
raise RuntimeError("Expected model of type torch.nn.Module. Got None instead.")

if dynamic_axes is None:
dynamic_axes = dict(**input_metadata.dynamic_axes, **output_metadata.dynamic_axes)
LOGGER.warning(f"No dynamic axes provided. Using values derived from the dataloader: {dynamic_axes}")
else:
_validate_if_dynamic_axes_aligns_with_dataloader_shapes(dynamic_axes, input_metadata, output_metadata)

model.to(target_device)

exporters.torch2onnx.get_model = lambda: model
Expand All @@ -191,7 +200,7 @@ def on_exit():
"opset": opset,
"input_names": list(input_metadata.keys()),
"output_names": list(output_metadata.keys()),
"dynamic_axes": dict(**input_metadata.dynamic_axes, **output_metadata.dynamic_axes),
"dynamic_axes": dynamic_axes,
"batch_dim": batch_dim,
"forward_kw_names": list(forward_kw_names) if forward_kw_names else None,
"target_device": target_device,
Expand All @@ -201,3 +210,20 @@ def on_exit():
context.execute_local_runtime_script(exporters.torch2onnx.__file__, exporters.torch2onnx.export, args)

return CommandOutput(status=CommandStatus.OK)


def _validate_if_dynamic_axes_aligns_with_dataloader_shapes(
dynamic_axes: Dict[str, Union[Dict[int, str], List[int]]],
input_metadata: TensorMetadata,
output_metadata: TensorMetadata,
):
for name, axes in dynamic_axes.items():
axes = list(axes)
tensor_spec = input_metadata.get(name, None) or output_metadata.get(name, None)
if tensor_spec is None:
raise ModelNavigatorConfigurationError(f"Dynamic axis {axes} is specified for unknown input {name}.")
for ax, d in enumerate(tensor_spec.shape):
if d == -1 and ax not in axes:
raise ModelNavigatorConfigurationError(
f"In tensor `{name}` axis `{ax}` is not set as dynamic axes but is dynamic in the dataloader."
)
28 changes: 0 additions & 28 deletions model_navigator/commands/infer_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,23 +98,6 @@ def _get_metadata_from_axes_shapes(axes_shapes, batch_dim, dtypes):
return metadata


def _update_user_dynamic_axes_inplace(dynamic_axes, input_metadata):
for name, axes in dynamic_axes.items():
axes = list(axes)
tensor_spec = input_metadata.get(name, None)
if tensor_spec is None:
return
shape = list(tensor_spec.shape)
for ax in axes:
shape[ax] = -1 # update
input_metadata[name] = TensorSpec(name, tuple(shape), tensor_spec.dtype)
for ax, d in enumerate(tensor_spec.shape):
if d == -1 and ax not in axes: # verify
raise ValueError(
f"In tensor `{name}` axis `{ax}` is not set as dynamic axes but is dynamic in the dataloader."
)


class InferInputMetadata(Command, is_required=True):
"""Command to collect model inputs metadata."""

Expand All @@ -125,7 +108,6 @@ def _run(
dataloader: SizedDataLoader,
_input_names: Optional[Tuple[str, ...]] = None,
batch_dim: Optional[int] = None,
dynamic_axes: Optional[Dict[str, Union[Dict[int, str], List[int]]]] = None,
) -> CommandOutput:
"""Execute the InferInputMetadata command.
Expand All @@ -134,7 +116,6 @@ def _run(
model: A model object or path to file
dataloader: Dataloader for providing samples
_input_names: Name of model inputs
dynamic_axes: Definition of model inputs dynamic axes
batch_dim: Location of batch dimension in data samples
Returns:
Expand All @@ -155,11 +136,6 @@ def _run(
dataloader_trt_profile = _get_trt_profile_from_axes_shapes(axes_shapes, batch_dim)

input_metadata = _get_metadata_from_axes_shapes(axes_shapes, batch_dim, input_dtypes)
if dynamic_axes is None:
LOGGER.warning(f"No dynamic axes provided. Using values derived from the dataloader: {input_metadata}")
else:
_update_user_dynamic_axes_inplace(dynamic_axes, input_metadata)

return CommandOutput(
status=CommandStatus.OK,
output={
Expand Down Expand Up @@ -261,10 +237,6 @@ def _run(
)

output_metadata = _get_metadata_from_axes_shapes(axes_shapes, batch_dim, output_dtypes)
if dynamic_axes is None:
LOGGER.warning(f"No dynamic axes provided. Using values derived from the dataloader: {output_metadata}")
else:
_update_user_dynamic_axes_inplace(dynamic_axes, output_metadata)

return CommandOutput(
status=CommandStatus.OK,
Expand Down
3 changes: 2 additions & 1 deletion model_navigator/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def to_numpy(a):
class TensorMetadata(Dict[str, TensorSpec]):
"""Metadata for inputs/outputs tensors."""

def add(self, name: str, shape: Sequence[int], dtype: Union[np.dtype, Type[np.dtype]]) -> None:
def add(self, name: str, shape: Sequence[int], dtype: Union[np.dtype, Type[np.dtype]]) -> "TensorMetadata":
"""Add new item to metadata.
Args:
Expand All @@ -195,6 +195,7 @@ def add(self, name: str, shape: Sequence[int], dtype: Union[np.dtype, Type[np.dt
dtype: Type of tensor data
"""
self[name] = TensorSpec(name, tuple(shape), np.dtype(dtype))
return self

@classmethod
def from_json(cls, data: List[Dict]) -> "TensorMetadata":
Expand Down
31 changes: 0 additions & 31 deletions tests/unit/base/test_infer_metadata_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,12 @@
# limitations under the License.

import numpy
import pytest

from model_navigator.api.config import TensorRTProfile
from model_navigator.commands.infer_metadata import (
_extract_max_batch_size,
_get_metadata_from_axes_shapes,
_get_trt_profile_from_axes_shapes,
_update_user_dynamic_axes_inplace,
)
from model_navigator.utils.tensor import TensorSpec

Expand Down Expand Up @@ -77,32 +75,3 @@ def test_get_metadata_return_correct_data_from_axes_shapes_when_with_valid_shape
metadata = _get_metadata_from_axes_shapes(axes_shapes=axes_shapes, batch_dim=batch_dim, dtypes=dtypes)

assert metadata == expected_metadata


def test_update_user_dynamic_axes_inplace_return_valid_metadata_when_tensors_with_dynamic_axis_passed():
input_name = "input_0"
dtype_name = "float64"
dynamic_axes = {input_name: (0,)}
input_metadata = {
input_name: TensorSpec(name=input_name, shape=(5, 224, 224, 3), dtype=numpy.dtype(dtype_name), optional=False)
}
expected_input_metadata = {
input_name: TensorSpec(name=input_name, shape=(-1, 224, 224, 3), dtype=numpy.dtype(dtype_name), optional=False)
}

_update_user_dynamic_axes_inplace(dynamic_axes=dynamic_axes, input_metadata=input_metadata)

assert input_metadata == expected_input_metadata


def test_update_user_dynamic_axes_inplace_raise_exception_when_metadata_and_dataloader_dynamic_axes_missmatch():
input_name = "input_0"
dtype_name = "float64"
dynamic_axes = {input_name: (0,)}
input_metadata = {
input_name: TensorSpec(name=input_name, shape=(-1, -1, 224, 3), dtype=numpy.dtype(dtype_name), optional=False)
}

with pytest.raises(ValueError):
# ValueError: In tensor `input_0` axis `1` is not set as dynamic axes but is dynamic in the dataloader.
_update_user_dynamic_axes_inplace(dynamic_axes=dynamic_axes, input_metadata=input_metadata)
45 changes: 45 additions & 0 deletions tests/unit/torch/test_torch_export_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest

from model_navigator.commands.export.torch import _validate_if_dynamic_axes_aligns_with_dataloader_shapes
from model_navigator.exceptions import ModelNavigatorConfigurationError
from model_navigator.utils.tensor import TensorMetadata


def test_validate_if_dynamic_axes_aligns_with_dataloader_shapes_raises_error_when_unknown_axes():
dynamic_axes = {"unknown_input": [0]}
input_metadata = TensorMetadata().add("input", [1, 2], np.float32)
output_metadata = TensorMetadata().add("output", [1, 2], np.float32)

with pytest.raises(ModelNavigatorConfigurationError):
_validate_if_dynamic_axes_aligns_with_dataloader_shapes(dynamic_axes, input_metadata, output_metadata)


def test_validate_if_dynamic_axes_aligns_with_dataloader_shapes_raises_error_when_missing_dataloader_dynamic_ax():
dynamic_axes = {"input": [0]}
input_metadata = TensorMetadata().add("input", [-1, -1], np.float32)
output_metadata = TensorMetadata().add("output", [1, 2], np.float32)

with pytest.raises(ModelNavigatorConfigurationError):
_validate_if_dynamic_axes_aligns_with_dataloader_shapes(dynamic_axes, input_metadata, output_metadata)


def test_validate_if_dynamic_axes_aligns_with_dataloader_shapes_raises_no_errors_when_dynamic_axes_aligns():
dynamic_axes = {"input": [0, 1]}
input_metadata = TensorMetadata().add("input", [-1, -1], np.float32)
output_metadata = TensorMetadata().add("output", [1, 2], np.float32)

_validate_if_dynamic_axes_aligns_with_dataloader_shapes(dynamic_axes, input_metadata, output_metadata)

0 comments on commit d701b15

Please sign in to comment.