Skip to content

Commit

Permalink
feat: Add support for TorchTensorRTModule in Dynamo
Browse files Browse the repository at this point in the history
- Rename `TRTModuleNext` to `TorchTensorRTModule` across the repository,
and move the source directory to `dynamo`
- Update imports across the repository
- Refactor `convert_module` code to support conversion to a
`TorchTensorRTModule`
- Add tests for `TorchTensorRTModule` functionality in Dynamo
  • Loading branch information
gs-olive committed Jun 8, 2023
1 parent 81d488a commit 8ec3f01
Show file tree
Hide file tree
Showing 18 changed files with 246 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/fx/fx2trt_example_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import torch_tensorrt.fx.tracer.acc_tracer.acc_tracer as acc_tracer
from torch_tensorrt.fx import InputTensorSpec, TRTInterpreter
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting
from torch_tensorrt import TRTModuleNext as TRTModule, Device
from torch_tensorrt import TorchTensorRTModule as TRTModule, Device

# The purpose of this example is to demonstrate the overall flow of lowering a PyTorch
# model to TensorRT via FX with existing FX based tooling. The general lowering flow
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,13 @@ def _find_lib(name, paths):
from torch_tensorrt import logging
from torch_tensorrt._Input import Input
from torch_tensorrt._Device import Device
from torch_tensorrt._TRTModuleNext import TRTModuleNext

from torch_tensorrt import fx

if version.parse(torch.__version__) >= version.parse("2.dev"):
from torch_tensorrt import dynamo
from torch_tensorrt.dynamo import backend
from torch_tensorrt.dynamo import TorchTensorRTModule


def _register_with_torch():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
from operator import truediv
from typing import Any, List, Sequence, Tuple
from typing import Any, List, Tuple

import torch
from torch_tensorrt import _C
Expand All @@ -9,8 +8,8 @@
logger = logging.getLogger(__name__)


class TRTModuleNext(torch.nn.Module):
"""TRTModuleNext is a PyTorch module which encompasses an arbitrary TensorRT Engine.
class TorchTensorRTModule(torch.nn.Module):
"""TorchTensorRTModule is a PyTorch module which encompasses an arbitrary TensorRT Engine.
This module is backed by the Torch-TensorRT runtime and is fully compatible with both
FX / Python deployments (just ``import torch_tensorrt`` as part of the application) as
Expand All @@ -20,7 +19,7 @@ class TRTModuleNext(torch.nn.Module):
The forward function is simply forward(*args: torch.Tensor) -> Tuple[torch.Tensor] where
the internal implementation is ``return Tuple(torch.ops.tensorrt.execute_engine(list(inputs), self.engine))``
> Note: TRTModuleNext only supports engines built with explict batch
> Note: TorchTensorRTModule only supports engines built with explicit batch
Attributes:
name (str): Name of module (for easier debugging)
Expand All @@ -37,7 +36,7 @@ def __init__(
output_binding_names: List[str] = [],
target_device: Device = Device._current_device(),
):
"""__init__ method for torch_tensorrt.TRTModuleNext
"""__init__ method for torch_tensorrt.TorchTensorRTModule
Takes a name, target device, serialized TensorRT engine, and binding names / order and constructs
a PyTorch ``torch.nn.Module`` around it.
Expand Down Expand Up @@ -71,9 +70,9 @@ def __init__(
"""
logger.warning(
"TRTModuleNext should be considered experimental stability, APIs are subject to change. Note: TRTModuleNext only supports engines built with explict batch"
"TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
)
super(TRTModuleNext, self).__init__()
super(TorchTensorRTModule, self).__init__()

if not isinstance(serialized_engine, bytearray):
ValueError("Expected serialized engine as bytearray")
Expand All @@ -89,8 +88,8 @@ def __init__(
self.name + "_engine" if self.name != "" else "tensorrt_engine",
target_device._to_serialized_rt_device(),
serialized_engine,
TRTModuleNext._pack_binding_names(self.input_binding_names),
TRTModuleNext._pack_binding_names(self.output_binding_names),
TorchTensorRTModule._pack_binding_names(self.input_binding_names),
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
]
)
else:
Expand Down Expand Up @@ -154,7 +153,7 @@ def is_non_tensor(i: Tuple[Any, bool]) -> bool:

non_tensors = [i[0] for i in filter(zip(inputs, types), is_non_tensor)]
raise RuntimeError(
f"TRTModuleNext expects a flattened list of tensors as input, found non tensors: {non_tensors}"
f"TorchTensorRTModule expects a flattened list of tensors as input, found non tensors: {non_tensors}"
)

outputs = torch.ops.tensorrt.execute_engine(list(inputs), self.engine)
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from torch_tensorrt.dynamo import fx_ts_compat
from .backend import compile
from ._TorchTensorRTModule import TorchTensorRTModule
13 changes: 13 additions & 0 deletions py/torch_tensorrt/dynamo/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
USE_EXPERIMENTAL_RT,
)


Expand Down Expand Up @@ -45,6 +46,7 @@ def compile(
min_block_size=MIN_BLOCK_SIZE,
torch_executed_ops=[],
torch_executed_modules=[],
use_experimental_rt=USE_EXPERIMENTAL_RT,
**kwargs,
):
if debug:
Expand All @@ -57,6 +59,13 @@ def compile(
+ "torch_executed_ops, pass_through_build_failures}"
)

if "use_experimental_fx_rt" in kwargs:
logger.info(
"Detected option 'use_experimental_fx_rt' in kwargs, "
+ "overwriting the 'use_experimental_rt' argument."
)
use_experimental_rt = kwargs["use_experimental_fx_rt"]

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]

Expand Down Expand Up @@ -86,6 +95,7 @@ def compile(
workspace_size=workspace_size,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
use_experimental_rt=use_experimental_rt,
**kwargs,
)

Expand All @@ -109,6 +119,7 @@ def create_backend(
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Sequence[str] = set(),
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
**kwargs,
):
"""Create torch.compile backend given specified arguments
Expand All @@ -120,6 +131,7 @@ def create_backend(
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
use_experimental_rt: Whether to use the new experimental TorchTensorRTModule for TRT engines
Returns:
Backend for torch.compile
"""
Expand All @@ -133,6 +145,7 @@ def create_backend(
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
pass_through_build_failures=pass_through_build_failures,
use_experimental_rt=use_experimental_rt,
)

return partial(
Expand Down
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
WORKSPACE_SIZE = 0
MIN_BLOCK_SIZE = 5
PASS_THROUGH_BUILD_FAILURES = False
USE_EXPERIMENTAL_RT = False
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/backend/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
WORKSPACE_SIZE,
MIN_BLOCK_SIZE,
PASS_THROUGH_BUILD_FAILURES,
USE_EXPERIMENTAL_RT,
)


Expand All @@ -19,3 +20,4 @@ class CompilationSettings:
min_block_size: int = MIN_BLOCK_SIZE
torch_executed_ops: Sequence[str] = field(default_factory=set)
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES
use_experimental_rt: bool = USE_EXPERIMENTAL_RT
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def _compile_module(
submodule,
submodule_inputs,
settings=settings,
name=name,
)

# Replace FX Module with TRT Module
Expand Down
28 changes: 21 additions & 7 deletions py/torch_tensorrt/dynamo/backend/conversion.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Sequence, Union
import torch
import io
from torch_tensorrt.fx.trt_module import TRTModule
from torch_tensorrt import TRTModuleNext
from torch_tensorrt.dynamo import TorchTensorRTModule
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from torch_tensorrt.dynamo.fx_ts_compat.fx2trt import (
InputTensorSpec,
Expand All @@ -15,12 +16,14 @@ def convert_module(
module: torch.fx.GraphModule,
inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
) -> Union[TRTModuleNext, TRTModule]:
name: str = "",
) -> Union[TorchTensorRTModule, TRTModule]:
"""Convert an FX module to a TRT module
Args:
module: FX GraphModule to convert
inputs: Sequence of Tensors representing inputs to the module
settings: Compilation settings
name: TRT engine name
Returns:
TRTModule or TorchTensorRTModule
"""
Expand All @@ -41,8 +44,19 @@ def convert_module(
),
)

return TRTModule(
engine=interpreter_result.engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
)
if settings.use_experimental_rt:
with io.BytesIO() as engine_bytes:
engine_bytes.write(interpreter_result.engine.serialize())
engine_str = engine_bytes.getvalue()
return TorchTensorRTModule(
serialized_engine=engine_str,
name=name,
input_binding_names=interpreter_result.input_names,
output_binding_names=interpreter_result.output_names,
)
else:
return TRTModule(
engine=interpreter_result.engine,
input_names=interpreter_result.input_names,
output_names=interpreter_result.output_names,
)
Loading

0 comments on commit 8ec3f01

Please sign in to comment.