diff --git a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
index 80c9b89977..8359bc62fb 100644
--- a/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/_TorchTensorRTModule.py
@@ -69,9 +69,6 @@ def __init__(
                 )

         """
-        logger.warning(
-            "TorchTensorRTModule should be considered experimental stability, APIs are subject to change. Note: TorchTensorRTModule only supports engines built with explict batch"
-        )
         super(TorchTensorRTModule, self).__init__()

         if not isinstance(serialized_engine, bytearray):
diff --git a/py/torch_tensorrt/dynamo/backend/__init__.py b/py/torch_tensorrt/dynamo/backend/__init__.py
index da2191f0cd..38e60fce41 100644
--- a/py/torch_tensorrt/dynamo/backend/__init__.py
+++ b/py/torch_tensorrt/dynamo/backend/__init__.py
@@ -19,7 +19,7 @@
     MAX_AUX_STREAMS,
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
 )


@@ -52,7 +52,7 @@ def compile(
     max_aux_streams=MAX_AUX_STREAMS,
     version_compatible=VERSION_COMPATIBLE,
     optimization_level=OPTIMIZATION_LEVEL,
-    use_experimental_rt=USE_EXPERIMENTAL_RT,
+    use_python_runtime=USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     if debug:
@@ -65,11 +65,6 @@ def compile(
             + "torch_executed_ops, pass_through_build_failures}"
         )

-    if "use_experimental_fx_rt" in kwargs:
-        use_experimental_rt = kwargs["use_experimental_fx_rt"]
-
-    logger.info(f"Using {'C++' if use_experimental_rt else 'Python'} TRT Runtime")
-
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]

@@ -107,7 +102,7 @@ def compile(
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
         **kwargs,
     )

@@ -134,7 +129,7 @@ def create_backend(
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
     version_compatible: bool = VERSION_COMPATIBLE,
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT,
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
     **kwargs,
 ):
     """Create torch.compile backend given specified arguments
@@ -150,7 +145,9 @@ def create_backend(
         version_compatible: Provide version forward-compatibility for engine plan files
         optimization_level: Builder optimization 0-5, higher levels imply longer build time,
             searching for more optimization options. TRT defaults to 3
-        use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
     Returns:
         Backend for torch.compile
     """
@@ -165,5 +162,6 @@ def create_backend(
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
-        use_experimental_rt=use_experimental_rt,
+        use_python_runtime=use_python_runtime,
+        **kwargs,
     )
diff --git a/py/torch_tensorrt/dynamo/backend/_defaults.py b/py/torch_tensorrt/dynamo/backend/_defaults.py
index 286c60c2fa..0afbc60f8c 100644
--- a/py/torch_tensorrt/dynamo/backend/_defaults.py
+++ b/py/torch_tensorrt/dynamo/backend/_defaults.py
@@ -9,4 +9,4 @@
 MAX_AUX_STREAMS = None
 VERSION_COMPATIBLE = False
 OPTIMIZATION_LEVEL = None
-USE_EXPERIMENTAL_RT = False
+USE_PYTHON_RUNTIME = None
diff --git a/py/torch_tensorrt/dynamo/backend/_settings.py b/py/torch_tensorrt/dynamo/backend/_settings.py
index 7ec4cc596e..d074a6b079 100644
--- a/py/torch_tensorrt/dynamo/backend/_settings.py
+++ b/py/torch_tensorrt/dynamo/backend/_settings.py
@@ -11,11 +11,11 @@
     MAX_AUX_STREAMS,
     VERSION_COMPATIBLE,
     OPTIMIZATION_LEVEL,
-    USE_EXPERIMENTAL_RT,
+    USE_PYTHON_RUNTIME,
 )


-@dataclass(frozen=True)
+@dataclass
 class CompilationSettings:
     precision: LowerPrecision = PRECISION
     debug: bool = DEBUG
@@ -26,4 +26,4 @@ class CompilationSettings:
     max_aux_streams: Optional[int] = MAX_AUX_STREAMS
     version_compatible: bool = VERSION_COMPATIBLE
     optimization_level: Optional[int] = OPTIMIZATION_LEVEL
-    use_experimental_rt: bool = USE_EXPERIMENTAL_RT
+    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index 1db043d55e..425fb0941e 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -55,7 +55,14 @@ def convert_module(
         optimization_level=settings.optimization_level,
     )

-    if settings.use_experimental_rt:
+    if settings.use_python_runtime:
+        return TRTModule(
+            engine=interpreter_result.engine,
+            input_names=interpreter_result.input_names,
+            output_names=interpreter_result.output_names,
+        )
+
+    else:
         from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

         with io.BytesIO() as engine_bytes:
@@ -67,9 +74,3 @@ def convert_module(
             input_binding_names=interpreter_result.input_names,
             output_binding_names=interpreter_result.output_names,
         )
-    else:
-        return TRTModule(
-            engine=interpreter_result.engine,
-            input_names=interpreter_result.input_names,
-            output_names=interpreter_result.output_names,
-        )
diff --git a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
index 625a9be1c2..2af251adbc 100644
--- a/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
+++ b/py/torch_tensorrt/dynamo/backend/test/test_backend_compiler.py
@@ -40,7 +40,7 @@ def forward(self, x, y):
             min_block_size=1,
             pass_through_build_failures=True,
             torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
             debug=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -108,7 +108,7 @@ def forward(self, x, y):
             min_block_size=1,
             pass_through_build_failures=True,
             torch_executed_ops={"torch.ops.aten.add.Tensor"},
-            use_experimental_rt=True,
+            use_python_runtime=False,
             debug=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -149,7 +149,7 @@ def forward(self, x, y):
             inputs,
             min_block_size=1,
             pass_through_build_failures=True,
-            use_experimental_rt=True,
+            use_python_runtime=False,
             optimization_level=4,
             version_compatible=True,
             max_aux_streams=5,
diff --git a/py/torch_tensorrt/dynamo/backend/utils.py b/py/torch_tensorrt/dynamo/backend/utils.py
index 9396373790..23a1cd4795 100644
--- a/py/torch_tensorrt/dynamo/backend/utils.py
+++ b/py/torch_tensorrt/dynamo/backend/utils.py
@@ -5,6 +5,7 @@
 from torch_tensorrt.dynamo.backend._settings import CompilationSettings
 from typing import Any, Union, Sequence, Dict
 from torch_tensorrt import _Input, Device
+from ..common_utils import use_python_runtime_parser

 logger = logging.getLogger(__name__)

@@ -102,6 +103,9 @@ def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
     if settings.debug:
         logger.setLevel(logging.DEBUG)

+    # Parse input runtime specification
+    settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
+
     logger.debug(f"Compiling with Settings:\n{settings}")

     return settings
diff --git a/py/torch_tensorrt/dynamo/common_utils/__init__.py b/py/torch_tensorrt/dynamo/common_utils/__init__.py
index e69de29bb2..de0ce0a48a 100644
--- a/py/torch_tensorrt/dynamo/common_utils/__init__.py
+++ b/py/torch_tensorrt/dynamo/common_utils/__init__.py
@@ -0,0 +1,36 @@
+import logging
+from typing import Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+def use_python_runtime_parser(use_python_runtime: Optional[bool] = None) -> bool:
+    """Parses a user-provided input argument regarding Python runtime
+
+    Automatically handles cases where the user has not specified a runtime (None)
+
+    Returns True if the Python runtime should be used, False if the C++ runtime should be used
+    """
+    using_python_runtime = use_python_runtime
+    reason = ""
+
+    # Runtime was manually specified by the user
+    if using_python_runtime is not None:
+        reason = "as requested by user"
+    # Runtime was not manually specified by the user, automatically detect runtime
+    else:
+        try:
+            from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule
+
+            using_python_runtime = False
+            reason = "since C++ dependency was detected as present"
+        except ImportError:
+            using_python_runtime = True
+            reason = "since import failed, C++ dependency not installed"
+
+    logger.info(
+        f"Using {'Python' if using_python_runtime else 'C++'} {reason} TRT Runtime"
+    )
+
+    return using_python_runtime
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
index 9af39b88ca..c0f1ae7870 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower.py
@@ -14,6 +14,7 @@
 from .lower_setting import LowerSetting
 from .passes.lower_pass_manager_builder import LowerPassManagerBuilder
 from .passes.pass_utils import PassFunc, validate_inference
+from ..common_utils import use_python_runtime_parser
 from torch_tensorrt.fx.tools.timing_cache_utils import TimingCacheManager
 from torch_tensorrt.fx.tools.trt_splitter import TRTSplitter, TRTSplitterSetting

@@ -48,7 +49,7 @@ def compile(
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_experimental_fx_rt=False,
+    use_python_runtime=None,
     max_aux_streams=None,
     version_compatible=False,
     optimization_level=None,
@@ -70,7 +71,9 @@ def compile(
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
+            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
+            argument as None
         max_aux_streams: max number of aux stream to use
         version_compatible: enable version compatible feature
         optimization_level: builder optimization level
@@ -111,6 +114,9 @@ def compile(
             "Invalid device provided. Supported options: torch.device | torch_tensorrt.Device"
         )

+    # Parse user-specification of which runtime to use
+    use_python_runtime = use_python_runtime_parser(use_python_runtime)
+
     lower_setting = LowerSetting(
         device=device,
         min_block_size=min_block_size,
@@ -123,7 +129,7 @@ def compile(
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
         is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        use_python_runtime=use_python_runtime,
         max_aux_streams=max_aux_streams,
         version_compatible=version_compatible,
         optimization_level=optimization_level,
@@ -202,7 +208,7 @@ def default_split_function(
     splitter_setting = TRTSplitterSetting()
     splitter_setting.use_implicit_batch_dim = False
     splitter_setting.min_block_size = lower_setting.min_block_size
-    splitter_setting.use_experimental_rt = lower_setting.use_experimental_rt
+    splitter_setting.use_experimental_rt = not lower_setting.use_python_runtime
     splitter = TRTSplitter(model, inputs, settings=splitter_setting)
     splitter.node_support_preview()
     return splitter.generate_split_results()
@@ -224,9 +230,17 @@ def lower_pass(
     """
     interpreter = create_trt_interpreter(lower_setting)
     interp_res: TRTInterpreterResult = interpreter(mod, input, module_name)
-    if lower_setting.use_experimental_rt:
-        import io
+    if lower_setting.use_python_runtime:
+        trt_module = TRTModule(
+            engine=interp_res.engine,
+            input_names=interp_res.input_names,
+            output_names=interp_res.output_names,
+            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
+        )
+        return trt_module

+    else:
+        import io
         from torch_tensorrt._Device import Device
         from torch_tensorrt.dynamo._TorchTensorRTModule import TorchTensorRTModule

@@ -240,16 +254,6 @@ def lower_pass(
             input_binding_names=interp_res.input_names,
             output_binding_names=interp_res.output_names,
             target_device=Device(f"cuda:{torch.cuda.current_device()}"),
-            # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
-        )
-        return trt_module
-
-    else:
-        trt_module = TRTModule(
-            engine=interp_res.engine,
-            input_names=interp_res.input_names,
-            output_names=interp_res.output_names,
-            cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,
         )
         return trt_module
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
index 64a67d1cc2..9301a2cd90 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py
@@ -68,7 +68,8 @@ class LowerSetting(LowerSettingBasic):
         meaning all possible tactic sources.
     correctness_atol: absolute tolerance for correctness check
     correctness_rtol: relative tolerance for correctness check
-    use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+    use_python_runtime: Whether to use Python runtime or C++ runtime. None implies the user has not
+        selected a runtime, and the frontend will automatically do so on their behalf
     max_aux_streams: max number of aux stream to use
     version_compatible: enable version compatible feature
     optimization_level: builder optimization level
@@ -95,7 +96,7 @@ class LowerSetting(LowerSettingBasic):
     tactic_sources: Optional[int] = None
     correctness_atol: float = 0.1
     correctness_rtol: float = 0.1
-    use_experimental_rt: bool = False
+    use_python_runtime: Optional[bool] = None
     max_aux_streams: Optional[int] = None
     version_compatible: bool = False
     optimization_level: Optional[int] = None
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
index 11371a92f9..bfb1964de9 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/tools/trt_minimizer.py
@@ -15,13 +15,20 @@
 def lower_mod_default(
     mod: torch.fx.GraphModule,
     inputs: Tensors,
-    use_experimental_rt: bool = False,
+    use_python_runtime: bool = False,
 ) -> TRTModule:
     interp = TRTInterpreter(
         mod, InputTensorSpec.from_tensors(inputs), explicit_batch_dimension=True
     )
     interpreter_result = interp.run()
-    if use_experimental_rt:
+    if use_python_runtime:
+        res_mod = TRTModule(
+            interpreter_result.engine,
+            interpreter_result.input_names,
+            interpreter_result.output_names,
+        )
+
+    else:
         import io

         from torch_tensorrt._Device import Device
@@ -39,12 +46,7 @@ def lower_mod_default(
             target_device=Device(f"cuda:{torch.cuda.current_device()}"),
             # cuda_graph_batch_size=lower_setting.cuda_graph_batch_size,  # NOTE: Not sure what this is supposed to do
         )
-    else:
-        res_mod = TRTModule(
-            interpreter_result.engine,
-            interpreter_result.input_names,
-            interpreter_result.output_names,
-        )
+
     return res_mod
diff --git a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
index 462fe04e70..f34aad6caf 100644
--- a/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
+++ b/py/torch_tensorrt/dynamo/test/test_dynamo_backend.py
@@ -31,6 +31,8 @@ def test_resnet18(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -62,6 +64,8 @@ def test_mobilenet_v2(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -93,6 +97,8 @@ def test_efficientnet_b0(ir):
         "enabled_precisions": {torch.float},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -133,6 +139,8 @@ def test_bert_base_uncased(ir):
         "truncate_long_and_double": True,
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
@@ -168,6 +176,8 @@ def test_resnet18_half(ir):
         "enabled_precisions": {torch.half},
         "ir": ir,
         "pass_through_build_failures": True,
+        "optimization_level": 1,
+        "min_block_size": 8,
     }

     trt_mod = torchtrt.compile(model, **compile_spec)
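
Usage sketch for the tri-state use_python_runtime argument introduced in this diff. The entry point and parser names follow the signatures above; the placeholder model and inputs are assumptions for illustration only, and a CUDA-capable environment with TensorRT installed is assumed.

    import torch
    from torch_tensorrt.dynamo.backend import compile as dynamo_compile
    from torch_tensorrt.dynamo.common_utils import use_python_runtime_parser

    # Placeholder model/input pair, not part of this change
    model = torch.nn.Sequential(torch.nn.Conv2d(3, 16, 3), torch.nn.ReLU()).eval().cuda()
    inputs = [torch.randn(1, 3, 224, 224).cuda()]

    # use_python_runtime=None (the new default) auto-selects the runtime:
    # C++ TorchTensorRTModule if the C++ dependency imports, else Python TRTModule
    trt_model = dynamo_compile(model, inputs, use_python_runtime=None)

    # An explicit True/False pins the Python or C++ runtime, respectively
    trt_model_py = dynamo_compile(model, inputs, use_python_runtime=True)

    # Explicit choices pass straight through the shared parser
    assert use_python_runtime_parser(True) is True
    assert use_python_runtime_parser(False) is False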