feat: Automatically detect C++ dependency presence
- By default, automatically test for the presence of the C++ dependency and use
the appropriate runtime when the user does not specify one
gs-olive committed Jul 5, 2023
1 parent e2594b6 commit f31e528
Showing 11 changed files with 97 additions and 54 deletions.
3 changes: 0 additions & 3 deletions py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
@@ -69,9 +69,6 @@ def __init__(
)
"""
-        logger.warning(
-            "TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
-        )
super(TorchTensorRTModule, self).__init__()

if not isinstance(serialized_engine, bytearray):
20 changes: 9 additions & 11 deletions py/torch_tensorrt/dynamo/backend/__init__.py
@@ -19,7 +19,7 @@
MAX_AUX_STREAMS,
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
)


@@ -52,7 +52,7 @@ def compile(
max_aux_streams=MAX_AUX_STREAMS,
version_compatible=VERSION_COMPATIBLE,
optimization_level=OPTIMIZATION_LEVEL,
-    use_experimental_rt=USE_EXPERIMENTAL_RT,
+    use_python_runtime=USE_PYTHON_RUNTIME,
**kwargs,
):
if debug:
@@ -65,11 +65,6 @@
+ "torch_executed_ops, pass_through_build_failures}"
)

if "use_experimental_fx_rt" in kwargs:
use_experimental_rt = kwargs["use_experimental_fx_rt"]

logger.info(f"Using {'C++' if use_experimental_rt else 'Python'} TRT Runtime")

if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]

@@ -107,7 +102,7 @@ def compile(
max_aux_streams=max_aux_streams,
version_compatible=version_compatible,
optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
**kwargs,
)

@@ -134,7 +129,7 @@ def create_backend(
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
version_compatible: bool = VERSION_COMPATIBLE,
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
**kwargs,
):
"""Create torch.compile backend given specified arguments
@@ -150,7 +145,9 @@
version_compatible: Provide version forward-compatibility for engine plan files
optimization_level: Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
-        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
Returns:
Backend for torch.compile
"""
@@ -165,5 +162,6 @@
max_aux_streams=max_aux_streams,
version_compatible=version_compatible,
optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
**kwargs,
)
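For orientation, a hedged usage sketch of the renamed argument through `create_backend`, which the docstring above describes as returning a backend for `torch.compile`. The import path is inferred from this file's location, and the model/inputs are placeholders assuming a CUDA-enabled torch_tensorrt install:

```python
import torch
from torch_tensorrt.dynamo.backend import create_backend

# use_python_runtime accepts three values after this commit:
#   None  -> auto-detect: prefer the C++ runtime when its dependency imports
#   True  -> force the pure-Python TRTModule runtime
#   False -> force the C++ TorchTensorRTModule runtime
backend = create_backend(use_python_runtime=None)

model = torch.nn.Linear(8, 4).eval().cuda()
optimized = torch.compile(model, backend=backend)
print(optimized(torch.randn(2, 8).cuda()).shape)
```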
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -9,4 +9,4 @@
MAX_AUX_STREAMS = None
VERSION_COMPATIBLE = False
OPTIMIZATION_LEVEL = None
-USE_EXPERIMENTAL_RT = False
+USE_PYTHON_RUNTIME = None
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/backend/_settings.py
@@ -11,11 +11,11 @@
MAX_AUX_STREAMS,
VERSION_COMPATIBLE,
OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
)


-@dataclass(frozen=True)
+@dataclass
class CompilationSettings:
precision: LowerPrecision = PRECISION
debug: bool = DEBUG
@@ -26,4 +26,4 @@ class CompilationSettings:
max_aux_streams: Optional[int] = MAX_AUX_STREAMS
version_compatible: bool = VERSION_COMPATIBLE
optimization_level: Optional[int] = OPTIMIZATION_LEVEL
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
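The dropped `frozen=True` pairs with the new step in `utils.py` below, where `parse_dynamo_kwargs` assigns to `settings.use_python_runtime` after the dataclass has been constructed — an assignment a frozen dataclass would reject. A standalone sketch of that constraint (names here are illustrative, not from the commit):

```python
from dataclasses import FrozenInstanceError, dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class FrozenSettings:
    use_python_runtime: Optional[bool] = None

s = FrozenSettings()
try:
    s.use_python_runtime = False  # the kind of post-init assignment utils.py now performs
except FrozenInstanceError as err:
    print(f"frozen dataclass rejects assignment: {err}")

# dataclasses.replace would be the frozen-friendly alternative:
s2 = replace(s, use_python_runtime=False)
print(s2.use_python_runtime)  # False
```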
15 changes: 8 additions & 7 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -55,7 +55,14 @@ def convert_module(
optimization_level=settings.optimization_level,
)

-    if settings.use_experimental_rt:
+    if settings.use_python_runtime:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
+
+    else:
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

with io.BytesIO() as engine_bytes:
@@ -67,9 +74,3 @@
input_binding_names=interpreter_result.input_names,
output_binding_names=interpreter_result.output_names,
)
-    else:
-        return TRTModule(
-            engine=interpreter_result.engine,
-            input_names=interpreter_result.input_names,
-            output_names=interpreter_result.output_names,
-        )
(test file — name not shown)
@@ -40,7 +40,7 @@ def forward(self, x, y):
min_block_size=1,
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -108,7 +108,7 @@ def forward(self, x, y):
min_block_size=1,
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
debug=True,
)
optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -149,7 +149,7 @@ def forward(self, x, y):
inputs,
min_block_size=1,
pass_through_build_failures=True,
-            use_experimental_rt=True,
+            use_python_runtime=False,
optimization_level=4,
version_compatible=True,
max_aux_streams=5,
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/backend/utils.py
@@ -5,6 +5,7 @@
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from typing import Any, Union, Sequence, Dict
from torch_tensorrt import _Input, Device
+from ..common_utils import use_python_runtime_parser


logger = logging.getLogger(__name__)
@@ -102,6 +103,9 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
if settings.debug:
logger.setLevel(logging.DEBUG)

+    # Parse input runtime specification
+    settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)

logger.debug(f"Compiling with Settings:\n{settings}")

return settings
36 changes: 36 additions & 0 deletions py/torch_tensorrt/dynamo/common_utils/__init__.py
@@ -0,0 +1,36 @@
import logging
from typing import Optional


logger = logging.getLogger(__name__)


def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
"""Parses a user-provided input argument regarding Python runtime
Automatically handles cases where the user has not specified a runtime (None)
Returns True if the Python runtime should be used, False if the C++ runtime should be used
"""
using_python_runtime = use_python_runtime
reason = ""

# Runtime was manually specified by the user
if using_python_runtime is not None:
reason = "as requested by user"
# Runtime was not manually specified by the user, automatically detect runtime
else:
try:
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

using_python_runtime = False
reason = "since C++ dependency was detected as present"
except ImportError:
using_python_runtime = True
reason = "since import failed, C++ dependency not installed"

logger.info(
f"Using {'Python' if using_python_runtime else 'C++'} {reason} TRT Runtime"
)

return using_python_runtime
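A short illustration of the three input cases the new helper handles; the import path follows the file location above, and the outcome of the `None` case depends on whether the C++ dependency is installed locally:

```python
from torch_tensorrt.dynamo.common_utils import use_python_runtime_parser

# Explicit user choices pass through unchanged:
assert use_python_runtime_parser(True) is True    # force Python runtime
assert use_python_runtime_parser(False) is False  # force C++ runtime

# None triggers auto-detection: True only if importing the C++-backed
# TorchTensorRTModule fails (i.e., the C++ dependency is absent).
selected = use_python_runtime_parser(None)
print(f"auto-selected the {'Python' if selected else 'C++'} runtime")
```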
36 changes: 20 additions & 16 deletions py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -14,6 +14,7 @@
from .lower_setting import LowerSetting
from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
from .passes.pass_utils import PassFunc, validate_inference
+from ..common_utils import use_python_runtime_parser
from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -48,7 +49,7 @@ def compile(
save_timing_cache=False,
cuda_graph_batch_size=-1,
is_aten=False,
-    use_experimental_fx_rt=False,
+    use_python_runtime=None,
max_aux_streams=None,
version_compatible=False,
optimization_level=None,
@@ -70,7 +71,9 @@
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
save_timing_cache: Update timing cache with current timing cache data if set to True.
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
max_aux_streams: max number of aux stream to use
version_compatible: enable version compatible feature
optimization_level: builder optimization level
@@ -111,6 +114,9 @@
"Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
)

+    # Parse user-specification of which runtime to use
+    use_python_runtime = use_python_runtime_parser(use_python_runtime)

lower_setting = LowerSetting(
device=device,
min_block_size=min_block_size,
@@ -123,7 +129,7 @@
save_timing_cache=save_timing_cache,
cuda_graph_batch_size=cuda_graph_batch_size,
is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        use_python_runtime=use_python_runtime,
max_aux_streams=max_aux_streams,
version_compatible=version_compatible,
optimization_level=optimization_level,
@@ -202,7 +208,7 @@ def default_split_function(
splitter_setting = TRTSplitterSetting()
splitter_setting.use_implicit_batch_dim = False
splitter_setting.min_block_size = lower_setting.min_block_size
-    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
splitter = TRTSplitter(model, inputs, settings=splitter_setting)
splitter.node_support_preview()
return splitter.generate_split_results()
@@ -224,9 +230,17 @@ def lower_pass(
"""
interpreter = create_trt_interpreter(lower_setting)
interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
-    if lower_setting.use_experimental_rt:
-        import io
+    if lower_setting.use_python_runtime:
+        trt_module = TRTModule(
+            engine=interp_res.engine,
+            input_names=interp_res.input_names,
+            output_names=interp_res.output_names,
+            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+        )
+        return trt_module
+
+    else:
+        import io
from torch_tensorrt._Device import Device
from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

@@ -240,16 +254,6 @@ def lower_pass(
input_binding_names=interp_res.input_names,
output_binding_names=interp_res.output_names,
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
-            # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
)
return trt_module

-    else:
-        trt_module = TRTModule(
-            engine=interp_res.engine,
-            input_names=interp_res.input_names,
-            output_names=interp_res.output_names,
-            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
-        )
-        return trt_module
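One subtlety in this file: `TRTSplitterSetting` still exposes the legacy `use_experimental_rt` flag, so `default_split_function` negates the new setting at that boundary. A minimal truth table (illustrative only):

```python
# Mapping from the new lower_setting flag to the splitter's legacy flag:
for use_python_runtime in (True, False):
    use_experimental_rt = not use_python_runtime
    print(
        f"use_python_runtime={use_python_runtime} -> "
        f"splitter_setting.use_experimental_rt={use_experimental_rt}"
    )
# True  -> False  (pure-Python TRTModule path)
# False -> True   (C++ TorchTensorRTModule path)
```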

5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -68,7 +68,8 @@ class LowerSetting(LowerSettingBasic):
meaning all possible tactic sources.
correctness_atol: absolute tolerance for correctness check
correctness_rtol: relative tolerance for correctness check
-        use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to use Python runtime or C++ runtime. None implies the user has not
+            selected a runtime, and the frontend will automatically do so on their behalf
max_aux_streams: max number of aux stream to use
version_compatible: enable version compatible feature
optimization_level: builder optimization level
@@ -95,7 +96,7 @@ class LowerSetting(LowerSettingBasic):
tactic_sources: Optional[int] = None
correctness_atol: float = 0.1
correctness_rtol: float = 0.1
-    use_experimental_rt: bool = False
+    use_python_runtime: Optional[bool] = None
max_aux_streams: Optional[int] = None
version_compatible: bool = False
optimization_level: Optional[int] = None
18 changes: 10 additions & 8 deletions py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
@@ -15,13 +15,20 @@
def lower_mod_default(
mod: torch.fx.GraphModule,
inputs: Tensors,
-    use_experimental_rt: bool = False,
+    use_python_runtime: bool = False,
) -> TRTModule:
interp = TRTInterpreter(
mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
)
interpreter_result = interp.run()
-    if use_experimental_rt:
+    if use_python_runtime:
+        res_mod = TRTModule(
+            interpreter_result.engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+        )
+
+    else:
import io

from torch_tensorrt._Device import Device
@@ -39,12 +46,7 @@ def lower_mod_default(
target_device=Device(f"cuda:{torch.cuda.current_device()}"),
# cuda_graph_batch_size=lower_setting.cuda_graph_batch_size, # NOTE: Not sure what this is supposed to do
)
-    else:
-        res_mod = TRTModule(
-            interpreter_result.engine,
-            interpreter_result.input_names,
-            interpreter_result.output_names,
-        )

return res_mod

