🐛 [Bug] torch_tensorrt.ts.convert_method_to_trt_engine -> Unsupported operator #2102

Closed
proevgenii opened this issue Jul 12, 2023 · 6 comments

@proevgenii

Bug Description

I want to save a compiled model as a TensorRT engine, but running torch_tensorrt.ts.convert_method_to_trt_engine(...) raises an error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[21], line 13
      1 import torch_tensorrt
      2 compile_spec = {
      3          "inputs": [torch_tensorrt.Input(
      4             max_shape=[256, 3, 224, 224],
   (...)
     11          
     12      }
---> 13 trt_engine = torch_tensorrt.ts.convert_method_to_trt_engine(optimized_model, "forward", **compile_spec)

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:241, in convert_method_to_trt_engine(module, method_name, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, truncate_long_and_double, calibrator)
    222     raise TypeError(
    223         "torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile"
    224     )
    226 compile_spec = {
    227     "inputs": inputs,
    228     "device": device,
   (...)
    238     "truncate_long_and_double": truncate_long_and_double,
    239 }
--> 241 engine_str = _C.convert_graph_to_trt_engine(
    242     module._c, method_name, _parse_compile_spec(compile_spec)
    243 )
    245 import io
    247 with io.BytesIO() as engine_bytes:

RuntimeError: [Error thrown at core/compiler.cpp:305] Expected conversion::VerifyConverterSupportForBlock(g->block()) to be true but got false
Not all operations in graph are supported by the compiler

To Reproduce

Steps to reproduce the behavior:

  1. Load ViT model from timm:
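import timm
import torch

# num_cls, checkpoints_name and device are defined elsewhere in the notebook
model_name = 'vit_base_patch32_224_clip_laion2b'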
model = timm.create_model(model_name, pretrained = True, num_classes=num_cls, exportable=True, scriptable=True)
model_dct = torch.load(checkpoints_name, map_location = device)
model.load_state_dict(model_dct['state_dict_ema'])
model.eval().to(device)
  2. Compile model
model = model.eval().to(device)

inputs_trt = [
        torch_tensorrt.Input(
            max_shape=[64, 3, 224, 224],
            opt_shape=[32, 3, 224, 224],
            min_shape=[1, 3, 224, 224],
            dtype=torch.float32,)]
inputs_dummy = torch.rand((32, 3, 224, 224), dtype=torch.float32, device=device)
enabled_precisions = {torch.float,}

traced_model = torch.jit.trace(model, inputs_dummy)

optimized_model = torch_tensorrt.compile(
        traced_model, inputs=inputs_trt, enabled_precisions=enabled_precisions, truncate_long_and_double=True,)

Here I get some warnings

Output here
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
  3. Convert to TRT engine
compile_spec = {
         "inputs": [torch_tensorrt.Input(
            max_shape=[64, 3, 224, 224],
            opt_shape=[32, 3, 224, 224],
            min_shape=[1, 3, 224, 224],
            dtype=torch.float32,
        )],
         "enabled_precisions": torch.float,
        "truncate_long_and_double": True }
trt_engine = torch_tensorrt.ts.convert_method_to_trt_engine(optimized_model, "forward", **compile_spec)
Output here: the same RuntimeError traceback as shown in the Bug Description above.
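
For reference, once the conversion does succeed, convert_method_to_trt_engine returns the serialized engine as a bytes object (the traceback shows it being assembled in an io.BytesIO buffer), so saving it is a plain file write. A minimal sketch, assuming that return type; save_engine, load_engine and the file name are hypothetical helpers for illustration, not part of torch_tensorrt:

```python
from pathlib import Path


def save_engine(engine_bytes: bytes, path: str) -> int:
    """Write a serialized TensorRT engine to disk and return the byte count."""
    # Guard against accidentally passing a module or an empty buffer.
    if not isinstance(engine_bytes, (bytes, bytearray)) or len(engine_bytes) == 0:
        raise ValueError("expected a non-empty bytes object holding the serialized engine")
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    return target.write_bytes(bytes(engine_bytes))


def load_engine(path: str) -> bytes:
    """Read the serialized engine back, e.g. to hand it to a TensorRT runtime."""
    return Path(path).read_bytes()
```

With a successful conversion the call would look like save_engine(trt_engine, "vit.engine"), where trt_engine is the value returned by convert_method_to_trt_engine.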

Environment

  • I'm running this inside the NVIDIA container: nvcr.io/nvidia/pytorch:23.04-py3
  • Torch-TensorRT Version (e.g. 1.0.0): '1.4.0.dev0'
  • PyTorch Version (e.g. 1.0): '2.1.0a0+fe05266'
  • Python version: Python 3.8.10
  • CUDA version: 12.1
  • GPU models and configuration: Tesla T4
@proevgenii proevgenii added the bug Something isn't working label Jul 12, 2023
@narendasan
Collaborator

When using torch_tensorrt.ts.convert_method_to_trt_engine, the full model must be supported in TensorRT, because no PyTorch runtime is available to run operations outside TensorRT once the engine has been constructed. @peri044, do you remember what in ViT isn't supported? The compiler should also be able to tell you what is missing. I'm not sure why it isn't doing this as part of the error (maybe that patch landed after 23.04), but increasing the logging verbosity should show you what is missing:

with torch_tensorrt.logging.debug():
    torch_tensorrt.compile(...)

@proevgenii
Author

I'm almost sure the model is supported in TensorRT, because I converted this ViT model with the trtexec command and everything worked as it should.

Output for `with torch_tensorrt.logging.debug():`
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 14
      2 compile_spec = {
      3          "inputs": [torch_tensorrt.Input(
      4             max_shape=[64, 3, 224, 224],
   (...)
     11          
     12      }
     13 with torch_tensorrt.logging.debug():
---> 14     trt_engine = torch_tensorrt.ts.convert_method_to_trt_engine(optimized_model, "forward", **compile_spec)

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:241, in convert_method_to_trt_engine(module, method_name, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, truncate_long_and_double, calibrator)
    222     raise TypeError(
    223         "torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile"
    224     )
    226 compile_spec = {
    227     "inputs": inputs,
    228     "device": device,
   (...)
    238     "truncate_long_and_double": truncate_long_and_double,
    239 }
--> 241 engine_str = _C.convert_graph_to_trt_engine(
    242     module._c, method_name, _parse_compile_spec(compile_spec)
    243 )
    245 import io
    247 with io.BytesIO() as engine_bytes:

RuntimeError: [Error thrown at core/compiler.cpp:305] Expected conversion::VerifyConverterSupportForBlock(g->block()) to be true but got false
Not all operations in graph are supported by the compiler

@peri044
Collaborator

peri044 commented Jul 17, 2023

> I converted this ViT model using the trtexec command and everything worked as it should

I assume you used ONNX-TRT for this. That result does not carry over to Torch-TRT.

  1. The example you provided uses dynamic shapes for the ViT model, which Torch-TRT does not currently support. The ViT graph has an aten::size layer whose output is passed to other layers. We recently added experimental support for ShapeTensors via --allow-shape-tensors, so that an aten::size layer can emit ShapeTensors instead of the static values it produces for static-shaped inputs. However, not all converters handle ShapeTensors, and for ViT with dynamic shapes a few converters fail when they receive ShapeTensors as inputs from an aten::size layer.

  2. It seems the debug logs are not being emitted, which is strange.
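
Whether an input spec counts as dynamic comes down to whether its min and max shapes differ in any dimension. A minimal sketch of that check in plain Python, assuming the usual TensorRT optimization-profile constraint that min <= opt <= max holds per dimension; check_shape_range is a hypothetical helper for illustration and not part of torch_tensorrt:

```python
from typing import Sequence


def check_shape_range(
    min_shape: Sequence[int],
    opt_shape: Sequence[int],
    max_shape: Sequence[int],
) -> bool:
    """Validate min <= opt <= max per dimension; return True if any dimension is dynamic."""
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min, opt and max shapes must have the same rank")
    for axis, (lo, opt, hi) in enumerate(zip(min_shape, opt_shape, max_shape)):
        if not (lo <= opt <= hi):
            raise ValueError(f"axis {axis}: expected min <= opt <= max, got {lo}, {opt}, {hi}")
    # A dimension is dynamic when its lower and upper bounds differ.
    return any(lo != hi for lo, hi in zip(min_shape, max_shape))
```

For the spec in the original report (min [1, 3, 224, 224], opt [32, 3, 224, 224], max [64, 3, 224, 224]) the batch dimension is dynamic, so this returns True; a single fixed shape passed three times returns False.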

@proevgenii
Author

  1. Even with a static input shape, e.g. inputs_trt = [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)], in both torch_tensorrt.compile(...) and torch_tensorrt.ts.convert_method_to_trt_engine(...), I still get the error when running this code:
compile_spec = {
         "inputs": [torch_tensorrt.Input(
             (1, 3, 224, 224),
            dtype=torch.float32,
        )],
         "enabled_precisions": torch.float,
        "truncate_long_and_double": True
         
     }
with torch_tensorrt.logging.debug():
    trt_engine = torch_tensorrt.ts.convert_method_to_trt_engine(optimized_model, "forward", **compile_spec)
Log here
DEBUG: [Torch-TensorRT] - TensorRT Compile Spec: {
    "Inputs": [
Input(shape=(1,3,224,224,), dtype=Float, format=Contiguous/Linear/NCHW, tensor_domain=[0, 2))    ]
    "Enabled Precision": [Float, ]
    "TF32 Disabled": 0
    "Sparsity": 0
    "Refit": 0
    "Debug": 0
    "Device":  {
        "device_type": GPU
        "allow_gpu_fallback": False
        "gpu_id": 0
        "dla_core": -1
    }

    "Engine Capability": Default
    "Num Avg Timing Iters": 1
    "Workspace Size": 0
    "DLA SRAM Size": 1048576
    "DLA Local DRAM Size": 1073741824
    "DLA Global DRAM Size": 536870912
    "Truncate long and double": 1
    "Torch Fallback":  {
        "enabled": False
        "min_block_size": 1
        "forced_fallback_operators": [
        ]
        "forced_fallback_modules": [
        ]
    }
}
DEBUG: [Torch-TensorRT] - init_compile_spec with input vector
DEBUG: [Torch-TensorRT] - Settings requested for Lowering:
    torch_executed_modules: [
    ]
DEBUG: [Torch-TensorRT] - RemoveNOPs - Note: Removing operators that have no meaning in TRT
INFO: [Torch-TensorRT] - Lowered Graph: graph(%input_0 : Tensor):
  %self.__torch___timm_models_vision_transformer_VisionTransformer_trt_engine_ : __torch__.torch.classes.tensorrt.Engine = prim::Constant[value=object(0x855f39d0)]()
  %3 : Tensor[] = prim::ListConstruct(%input_0)
  %4 : Tensor[] = tensorrt::execute_engine(%3, %self.__torch___timm_models_vision_transformer_VisionTransformer_trt_engine_)
  %5 : Tensor = prim::ListUnpack(%4)
  return (%5)

ERROR: [Torch-TensorRT] - Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.
Unsupported operators listed below:
  - tensorrt::execute_engine(Tensor[] _0, __torch__.torch.classes.tensorrt.Engine _1) -> Tensor[] _0
You can either implement converters for these ops in your application or request implementation
https://www.github.com/nvidia/Torch-TensorRT/issues

In Module:

ERROR: [Torch-TensorRT] - Unsupported operator: tensorrt::execute_engine(Tensor[] _0, __torch__.torch.classes.tensorrt.Engine _1) -> Tensor[] _0

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[9], line 15
      2 compile_spec = {
      3          "inputs": [torch_tensorrt.Input(
      4             # max_shape=[64, 3, 224, 224],
   (...)
     12          
     13      }
     14 with torch_tensorrt.logging.debug():
---> 15     trt_engine = torch_tensorrt.ts.convert_method_to_trt_engine(optimized_model, "forward", **compile_spec)

File /usr/local/lib/python3.8/dist-packages/torch_tensorrt/ts/_compiler.py:241, in convert_method_to_trt_engine(module, method_name, inputs, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, truncate_long_and_double, calibrator)
    222     raise TypeError(
    223         "torch.jit.ScriptFunctions currently are not directly supported, wrap the function in a module to compile"
    224     )
    226 compile_spec = {
    227     "inputs": inputs,
    228     "device": device,
   (...)
    238     "truncate_long_and_double": truncate_long_and_double,
    239 }
--> 241 engine_str = _C.convert_graph_to_trt_engine(
    242     module._c, method_name, _parse_compile_spec(compile_spec)
    243 )
    245 import io
    247 with io.BytesIO() as engine_bytes:

RuntimeError: [Error thrown at core/compiler.cpp:305] Expected conversion::VerifyConverterSupportForBlock(g->block()) to be true but got false
Not all operations in graph are supported by the compiler

Interestingly, even when I compile the model with a static input shape, it still emits warnings about dynamic shapes:

inputs_trt = [torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)]
optimized_model = torch_tensorrt.compile(
        traced_model, inputs=inputs_trt, enabled_precisions=enabled_precisions, truncate_long_and_double=True,
    )
Output here
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - There may be undefined behavior using dynamic shape and aten::size
WARNING: [Torch-TensorRT] - Truncating weight (constant in the graph) from Float64 to Float32
  2. That was my mistake: last time I missed part of the log.
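
The lowered graph in the debug log above contains only tensorrt::execute_engine, which suggests that optimized_model (the output of torch_tensorrt.compile) already wraps a TensorRT engine, leaving the second conversion with an operator it has no converter for. One likely remedy, not confirmed in this thread, is to pass traced_model to convert_method_to_trt_engine instead. A minimal sketch of detecting this case from the graph text; ops_in_graph and is_already_compiled are hypothetical helpers, and the regex is an assumption about how TorchScript IR prints operator kinds:

```python
import re
from typing import List

# Matches the operator kind on the right-hand side of a TorchScript IR assignment,
# e.g. "= prim::ListConstruct(" or "= tensorrt::execute_engine(".
_OP_PATTERN = re.compile(r"=\s*([A-Za-z_]\w*::\w+)")

EMBEDDED_ENGINE_OP = "tensorrt::execute_engine"


def ops_in_graph(graph_text: str) -> List[str]:
    """Return the operator kinds found in the graph text, in order of first appearance."""
    seen: List[str] = []
    for op in _OP_PATTERN.findall(graph_text):
        if op not in seen:
            seen.append(op)
    return seen


def is_already_compiled(graph_text: str) -> bool:
    """True if the graph already calls into an embedded TensorRT engine."""
    return EMBEDDED_ENGINE_OP in ops_in_graph(graph_text)


# Lowered graph copied from the debug log in this thread.
LOWERED_GRAPH = """graph(%input_0 : Tensor):
  %self.__torch___timm_models_vision_transformer_VisionTransformer_trt_engine_ : __torch__.torch.classes.tensorrt.Engine = prim::Constant[value=object(0x855f39d0)]()
  %3 : Tensor[] = prim::ListConstruct(%input_0)
  %4 : Tensor[] = tensorrt::execute_engine(%3, %self.__torch___timm_models_vision_transformer_VisionTransformer_trt_engine_)
  %5 : Tensor = prim::ListUnpack(%4)
  return (%5)
"""

print(ops_in_graph(LOWERED_GRAPH))
print(is_already_compiled(LOWERED_GRAPH))
```

In a live session the text could presumably come from str(optimized_model.graph); if is_already_compiled returns True, the module has already been through torch_tensorrt.compile and the uncompiled TorchScript module is the one to convert.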

@proevgenii
Author

@peri044 Any updates?

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
