🐛 [Bug] Unable to compile DPT Midas depth estimation model using Torch-TensorRT. #2083
Comments
Can you enable debug logging and share the logs? This would help us determine the root cause.
Hi @garg-aayush - thanks for the report. With a small change to the DPT model, I am able to reproduce the error you describe, and the problematic operator seems to be `aten::upsample_bilinear2d`, based on this log excerpt:

```
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Evaluating %23 : int[] = prim::ListConstruct(%17, %21)
DEBUG: [Torch-TensorRT TorchScript Conversion Context] - Found the value to be: [24, 32]
INFO: [Torch-TensorRT TorchScript Conversion Context] - Adding Layer %posemb_grid.3 : Tensor = aten::upsample_bilinear2d(%25, %23, %26, %27)
...
RuntimeError: [Error thrown at core/conversion/var/Var.cpp:136] Expected isITensor() to be true but got false
Requested ITensor from Var, however Var type is c10::IValue
```

A temporary workaround which makes compilation functional on my machine is to add these two lines to the compilation:

```python
trt_model_fp32 = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
    enabled_precisions=torch.float32,  # Run with FP32
    # Added the below:
    torch_executed_ops=["aten::upsample_bilinear2d"],
    truncate_long_and_double=True,
)
```
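For context, the failing pattern can be reproduced in plain PyTorch. The sketch below is my own minimal example, not the DPT source; the function name and shapes are assumptions, with the output size taken from the `[24, 32]` value in the logs. It shows a bilinear upsample whose output size is built at runtime from Python ints, which is what the `prim::ListConstruct` log lines point at:

```python
import torch
import torch.nn.functional as F

# Minimal sketch (assumption: this mirrors the positional-embedding resize in
# DPT's ViT backbone). The output size is a runtime-constructed list of ints,
# which lowers to aten::upsample_bilinear2d with a prim::ListConstruct input.
def resize_posemb(posemb_grid: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
    return F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False)

grid = torch.randn(1, 768, 16, 16)  # illustrative embedding grid
out = resize_posemb(grid, 24, 32)
print(out.shape)  # torch.Size([1, 768, 24, 32])
```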
@bowang007 @gs-olive Thank you for the reply. I made the suggested changes; however, I am now getting the following error. Note, I am working with the latest Torch-TensorRT (1.4.0).

Error
Code

```python
import torch
from dpt.models import DPTDepthModel
import torch_tensorrt

# select device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_weights = 'weights/dpt_hybrid-midas-501f0c75.pt'
model = DPTDepthModel(
    path=model_weights,
    backbone="vitb_rn50_384",
    non_negative=True,
    enable_attention_hooks=False,
)
model.eval()
model.to(device)

with torch_tensorrt.logging.debug():
    traced_model = torch.jit.trace(
        model,
        [torch.randn((1, 3, 384, 512)).to("cuda")],
    )
    trt_model_fp32 = torch_tensorrt.compile(
        traced_model,
        inputs=[torch_tensorrt.Input((1, 3, 384, 512), dtype=torch.float32)],
        enabled_precisions=torch.float32,  # Run with FP32
        workspace_size=1 << 22,
        torch_executed_ops=["aten::upsample_bilinear2d"],
        truncate_long_and_double=True,
    )
```

I am attaching the full debug log for your reference.
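As a side note on `truncate_long_and_double=True`: TensorRT does not natively support 64-bit integer or double-precision tensors, so this flag lets Torch-TensorRT cast them down to 32-bit before conversion. A minimal sketch of that cast in plain PyTorch (the values below are illustrative, borrowing the `[24, 32]` size from the earlier logs; this mimics the flag's documented purpose, not Torch-TensorRT's internals):

```python
import torch

# Illustrative: TensorRT lacks native int64/float64 support, so long/double
# tensors must be truncated to their 32-bit equivalents.
size_ints = torch.tensor([24, 32], dtype=torch.int64)   # e.g. an upsample size
weights = torch.randn(4, dtype=torch.float64)           # e.g. a double buffer

truncated_ints = size_ints.to(torch.int32)
truncated_weights = weights.to(torch.float32)

print(truncated_ints.dtype, truncated_weights.dtype)  # torch.int32 torch.float32
```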
Hi @garg-aayush, with your provided code, I am able to reproduce the error. Based on this line in the debug logs:

It seems that the
Hi @gs-olive, Thanks! It works now.
I saw in #1853 that you guys have fixed it. Am I correct? Do I need to upgrade my Torch-TensorRT for it?
Hi @garg-aayush - sounds good. Regarding the error, the fix in #1853 is still in progress and has not yet been merged. It is intended to address issues like the one you describe.
I will wait for the merge. Thanks a lot for the help!
Bug Description
I am trying to compile the DPT-Hybrid depth estimation model using `Torch-TensorRT`. However, when I try to compile the traced DPT model using `torch_tensorrt.compile`, I get the following error.

Code snippet

Error

See below for the full error trace.
To Reproduce
Steps to reproduce
Expected behavior
I was hoping it would compile like the simple ResNet example.
Environment
How you installed PyTorch (`conda`, `pip`, `libtorch`, source): `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118`

Full error trace