
🐛 [Bug] Unable to compile DPT Midas depth estimation model using Torch-TensorRT. #2083

Closed
garg-aayush opened this issue Jul 7, 2023 · 7 comments
Labels: bug: triaged [verified] · bug · component: converters

@garg-aayush

Bug Description

I am trying to compile the DPT-Hybrid depth estimation model using Torch-TensorRT. However, when I try to compile the traced DPT model using torch_tensorrt.compile, I get the following error:

Code snippet

import torch
import torch_tensorrt

trt_model_fp32 = torch_tensorrt.compile(traced_model, inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
    enabled_precisions = torch.float32, # Run with FP32
    workspace_size = 1 << 22)

Error

RuntimeError: [Error thrown at core/conversion/var/Var.cpp:136] Expected isITensor() to be true but got false
Requested ITensor from Var, however Var type is c10::IValue

See below for the full error trace.

To Reproduce

Steps to reproduce

  1. Clone the repo and download the weights:
git clone https://github.com/isl-org/DPT
wget https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt
  2. Install TensorRT (8.5.1.7) and the Python packages for Torch-TensorRT (1.4.0)
  3. Install all the required dependencies for DPT (basically torch and timm)
  4. Run the following code:
import torch
import util.io

import torch_tensorrt
from dpt.models import DPTDepthModel
from dpt.midas_net import MidasNet_large

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_weights = 'weights/dpt_hybrid-midas-501f0c75.pt'
model_type = 'dpt_hybrid'
net_w = net_h = 384

model = DPTDepthModel(
    path=model_weights,
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)
model.eval()
model.to(device)
traced_model = torch.jit.trace(model, [torch.randn((1,3,384,512)).to("cuda")])

trt_model_fp32 = torch_tensorrt.compile(traced_model, inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
    enabled_precisions = torch.float32, # Run with FP32
    workspace_size = 1 << 22)

Expected behavior

I was hoping it would compile like the simple ResNet example.
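
For reference, the ResNet flow I was comparing against looks roughly like this (paraphrased from the standard Torch-TensorRT ResNet example; the exact example code may differ):

import torch
import torch_tensorrt
import torchvision.models as models

# Paraphrased sketch of the standard Torch-TensorRT ResNet example (not verbatim)
resnet = models.resnet50(pretrained=True).eval().to("cuda")
traced_resnet = torch.jit.trace(resnet, torch.randn((1, 3, 224, 224)).to("cuda"))
trt_resnet = torch_tensorrt.compile(
    traced_resnet,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions={torch.float32},
)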

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version: 1.4.0
  • PyTorch Version (e.g. 2.0.1):
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 20.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
  • Python version: 3.9.17
  • CUDA version: 11.8
  • GPU models and configuration: RTX4090

Full error trace

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[14], line 3
      1 import torch_tensorrt
----> 3 trt_model_fp32 = torch_tensorrt.compile(traced_model, inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
      4     enabled_precisions = torch.float32, # Run with FP32
      5     workspace_size = 1 << 22
      6 )

File ~/miniconda3/envs/torch_tensorrt/lib/python3.9/site-packages/torch_tensorrt/_compile.py:133, in compile(module, ir, inputs, enabled_precisions, **kwargs)
    128         logging.log(
    129             logging.Level.Info,
    130             "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
    131         )
    132         ts_mod = torch.jit.script(module)
--> 133     return torch_tensorrt.ts.compile(
    134         ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
    135     )
    136 elif target_ir == _IRType.fx:
    137     if (
    138         torch.float16 in enabled_precisions
    139         or torch_tensorrt.dtype.half in enabled_precisions
    140     ):

File ~/miniconda3/envs/torch_tensorrt/lib/python3.9/site-packages/torch_tensorrt/ts/_compiler.py:139, in compile(module, inputs, input_signature, device, disable_tf32, sparse_weights, enabled_precisions, refit, debug, capability, num_avg_timing_iters, workspace_size, dla_sram_size, dla_local_dram_size, dla_global_dram_size, calibrator, truncate_long_and_double, require_full_compilation, min_block_size, torch_executed_ops, torch_executed_modules, allow_shape_tensors)
    112     raise ValueError(
    113         f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
    114     )
    116 spec = {
    117     "inputs": inputs,
    118     "input_signature": input_signature,
   (...)
    136     "allow_shape_tensors": allow_shape_tensors,
    137 }
--> 139 compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
    140 compiled_module = torch.jit._recursive.wrap_cpp_module(compiled_cpp_mod)
    141 return compiled_module

RuntimeError: [Error thrown at core/conversion/var/Var.cpp:136] Expected isITensor() to be true but got false
Requested ITensor from Var, however Var type is c10::IValue
garg-aayush added the bug label on Jul 7, 2023
gs-olive self-assigned this on Jul 7, 2023

@narendasan (Collaborator)

Can you enable debug logging and share the logs? This would help us determine the root cause.

gs-olive added the bug: triaged [verified] label on Jul 7, 2023

@gs-olive (Collaborator) commented Jul 7, 2023

Hi @garg-aayush - thanks for the report. With a small change to the DPT model (adding int casts here), I was able to trace and begin model compilation using the latest main of Torch-TRT.

I am able to reproduce the error you describe, and the problematic operator seems to be aten::upsample_bilinear2d. Relevant logs:

DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %23 : int[] = prim::ListConstruct(%17, %21)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: [24, 32]
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %posemb_grid.3 : Tensor = aten::upsample_bilinear2d(%25, %23, %26, %27)

...

RuntimeError: [Error thrown at core/conversion/var/Var.cpp:136] Expected isITensor() to be true but got false
Requested ITensor from Var, however Var type is c10::IValue
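
For context, the failing node writes to %posemb_grid, which points at the positional-embedding resize in dpt/vit.py. A paraphrased, standalone sketch of that pattern (illustrative shapes and names, not verbatim repo code):

import torch
import torch.nn.functional as F

# Paraphrased sketch of DPT's positional-embedding resize (dpt/vit.py), not verbatim:
# this F.interpolate call is what the trace records as aten::upsample_bilinear2d,
# with a computed size of [24, 32] for a 384x512 input at an effective stride of 16.
posemb_grid = torch.randn(1, 768, 24, 24)      # illustrative embedding grid
gs_h, gs_w = 384 // 16, 512 // 16              # 24, 32
resized = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)
print(resized.shape)  # torch.Size([1, 768, 24, 32])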

A temporary workaround which makes compilation functional on my machine is to add these two lines to the compilation:

trt_model_fp32 = torch_tensorrt.compile(model, inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
    enabled_precisions = torch.float32, # Run with FP32

##### Added the below:
    torch_executed_ops = ["aten::upsample_bilinear2d"],
    truncate_long_and_double = True,
)
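
For context: torch_executed_ops keeps aten::upsample_bilinear2d running in PyTorch instead of converting it, and truncate_long_and_double downcasts int64/float64 values to int32/float32 for TensorRT. A minimal sketch of running the resulting, partially converted module (dummy input of the compiled shape; variable names are illustrative):

# Minimal sketch: run the partially converted module with a dummy input of the compiled shape
sample = torch.randn((1, 3, 384, 512), device="cuda")
with torch.no_grad():
    depth = trt_model_fp32(sample)
print(depth.shape)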

gs-olive added the component: converters label on Jul 7, 2023
@garg-aayush (Author)

@bowang007 @gs-olive Thank you for the reply. I made the suggested changes. However, I am getting the following error now.

Note: I am working with the latest Torch-TensorRT (1.4.0).

Error

RuntimeError: [Error thrown at core/conversion/conversionctx/ConversionCtx.cpp:169] Building serialized network failed in TensorRT

Code

import torch
from dpt.models import DPTDepthModel
import torch_tensorrt

# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_weights = 'weights/dpt_hybrid-midas-501f0c75.pt'

model = DPTDepthModel(
    path=model_weights,
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)
model.eval()
model.to(device)


with torch_tensorrt.logging.debug():
    traced_model = torch.jit.trace(model,
                                    [torch.randn((1,3,384,512)).to("cuda")]
                                    )   
    trt_model_fp32 = torch_tensorrt.compile(
        traced_model,
        inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
        enabled_precisions = torch.float32, # Run with FP32
        workspace_size = 1 << 22, 
        torch_executed_ops = ["aten::upsample_bilinear2d"],
        truncate_long_and_double = True,
        )

I am attaching the full debug log for your reference.
debug_log.txt

@gs-olive (Collaborator) commented Jul 10, 2023

Hi @garg-aayush, with your provided code, I am able to reproduce the error. Based on this line in the debug logs:

ERROR: [Torch-TensorRT TorchScript Conversion Context] - 4: Could not find any implementation for node {ForeignNode[%15 : Tensor = aten::flatten(%8, %16, %17) # /home/aayush/Projects/TensorRT_Examples/DPT/dpt/vit.py:188:0 + %x.167 : Tensor = aten::transpose(%15, %44, %16) # /home/aayush/Projects/TensorRT_Examples/DPT/dpt/vit.py:188:0...(Unnamed Layer* 40) [Shuffle]]} due to insufficient workspace. See verbose log for requested sizes.

It seems that the workspace_size is insufficient for one of the operations. Could you try the run again, omitting the workspace_size = 1 << 22 line from the compilation call? With this omission, the model compiles successfully on my machine.
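
For concreteness, the adjusted call would look roughly like this (a sketch; without an explicit workspace_size, TensorRT uses its default workspace limit):

trt_model_fp32 = torch_tensorrt.compile(
    traced_model,
    inputs = [torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
    enabled_precisions = torch.float32, # Run with FP32
    torch_executed_ops = ["aten::upsample_bilinear2d"],
    truncate_long_and_double = True,
)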

@garg-aayush
Copy link
Author

Hi @gs-olive, thanks! It works now.
One more question, though: I am trying to compile with float16/half. It compiles fine, but when I run inference, I get this error:

Expected input tensors to have type Half, found type float

I saw in #1853 that this has been fixed. Am I correct? Do I need to upgrade my Torch-TensorRT for it?

@gs-olive (Collaborator)

Hi @garg-aayush - sounds good. Regarding the error, the fix in #1853 is still in progress and has not yet been merged. It is intended to address issues like the one you describe.
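
In the meantime, one possible interim workaround (a sketch only, untested here, and assuming the module was compiled with half-precision inputs; trt_model_fp16 is an illustrative name) is to cast the inputs to half before calling the compiled module:

# Sketch of a possible interim workaround (untested): feed half-precision inputs so the
# input dtype matches what the compiled engine expects. trt_model_fp16 is illustrative.
sample = torch.randn((1, 3, 384, 512), device="cuda").half()
with torch.no_grad():
    depth = trt_model_fp16(sample)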

@garg-aayush (Author)

I will wait for the merge. Thanks a lot for the help!
