
🐛 [Bug] BF16 causing unspported numpy dtype error in create_constant #2902

Status: Closed · HolyWu opened this issue Jun 9, 2024 · 4 comments · Fixed by #2974

Labels: bug (Something isn't working)

Comments

@HolyWu (Contributor) commented Jun 9, 2024

Bug Description

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo.utils:Using Default Torch-TRT Runtime (as requested by user)
INFO:torch_tensorrt.dynamo.utils:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.bf16: 10>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False)

DEBUG:torch_tensorrt.dynamo.backend.backends:Pre-AOT Autograd graph:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%l_x_,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._repair_input_aliasing:Inserted auxiliary clone nodes for placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering._remove_sym_nodes:Removed SymInt placeholders:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %clone_default : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%l_x_,), kwargs = {})
    %l__self___linear : [num_users=1] = call_module[target=L__self___linear](args = (%clone_default,), kwargs = {})
    return (l__self___linear,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Post-AOT Autograd graph:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %clone, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removing node clone from graph, since it is a clone node which is the only user of placeholder arg0_1 and was inserted by the compiler.
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_param_constant0 : [num_users=1] = get_attr[target=_param_constant0]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%_param_constant0, [1, 0]), kwargs = {})
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
WARNING:torch_tensorrt.dynamo._compiler:Node _param_constant1 of op type get_attr does not have metadata. This could sometimes lead to undefined behavior.
WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Submodule name: _run_on_acc_0
 Input shapes: [(128, 20)]
 graph():
    %_param_constant1 : [num_users=1] = get_attr[target=_param_constant1]
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {})
    return addmm
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init CUDA: CPU +3, GPU +0, now: CPU 12984, GPU 1045 (MiB)
INFO:torch_tensorrt [TensorRT Conversion Context]:[MemUsageChange] Init builder kernel library: CPU +2657, GPU +308, now: CPU 15907, GPU 1353 (MiB)
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: arg0_1 [shape=[128, 20], dtype=DataType.BF16]
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node addmm (kind: aten.addmm.default, args: ('<torch.Tensor as np.ndarray [shape=(30,), dtype=float32]>', 'arg0_1 <tensorrt.ITensor [shape=(128, 20), dtype=DataType.BF16]>', '<torch.Tensor as np.ndarray [shape=(20, 30), dtype=float32]>'))
DEBUG:torch_tensorrt.dynamo.conversion.converter_utils:Freezing tensor addmm_constant_0 to TRT IConstantLayer
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 29, in <module>
    optimized_model(*inputs)
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\eval_frame.py", line 432, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1552, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\nn\modules\module.py", line 1561, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 1115, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 947, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 471, in __call__
    return _compile(
           ^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_utils_internal.py", line 83, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_strobelight\compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 816, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 635, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\bytecode_transformation.py", line 1184, in transform_code_object
    transformations(instructions, code_options)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 177, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\convert_frame.py", line 581, in transform
    tracer.run()
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2455, in run
    super().run()
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 897, in run
    while self.step():
          ^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 809, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2646, in RETURN_VALUE
    self._return(inst)
  File "C:\Python312\Lib\site-packages\torch\_dynamo\symbolic_convert.py", line 2631, in _return
    self.output.compile_subgraph(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1097, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "C:\Python312\Lib\contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1314, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\utils.py", line 232, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1405, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "C:\Python312\Lib\site-packages\torch\_dynamo\output_graph.py", line 1386, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\_dynamo\repro\after_dynamo.py", line 128, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\__init__.py", line 1989, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 44, in torch_tensorrt_backend
    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 52, in aot_torch_tensorrt_aten_backend
    return _pretraced_backend(gm, sample_inputs, settings)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\backend\backends.py", line 108, in _pretraced_backend
    trt_compiled = compile_module(
                   ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\_compiler.py", line 412, in compile_module
    trt_module = convert_module(
                 ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
                         ^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 310, in run
    super().run()
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 349, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
                              ^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch\fx\interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\_TRTInterpreter.py", line 457, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 469, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\aten_ops_converters.py", line 2714, in aten_ops_addmm
    return impl.addmm.addmm(
           ^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\addmm.py", line 24, in addmm
    mm = impl.matmul.matrix_multiply(ctx, target, source_ir, f"{name}_mm", mat1, mat2)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\impl\matmul.py", line 28, in matrix_multiply
    other = get_trt_tensor(
            ^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 328, in get_trt_tensor
    return create_constant(ctx, input_val, name, dtype, min_rank)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\dynamo\conversion\converter_utils.py", line 287, in create_constant
    value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Python312\Lib\site-packages\torch_tensorrt\_enums.py", line 279, in to
    raise TypeError("Unspported numpy dtype")
torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt_backend' raised:
TypeError: Unspported numpy dtype

While executing %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%_param_constant1, %arg0_1, %_frozen_param0), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x000001E4A3D7F930>: ((128, 20), torch.bfloat16, False, (20, 1), torch.contiguous_format, False, {})}})
Original traceback:
None

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

To Reproduce

import torch
import torch_tensorrt


class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(20, 30)

    def forward(self, x):
        return self.linear(x)


device = torch.device("cuda", 0)
model = MyModule().eval().to(device).bfloat16()
inputs = [torch.randn((128, 20), dtype=torch.bfloat16, device=device)]

with torch.inference_mode():
    optimized_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.bfloat16},
        debug=True,
        min_block_size=1,
        device=device,
    )

    optimized_model(*inputs)

Environment

  • Torch-TensorRT Version: 2.4.0.dev20240607+cu124
  • PyTorch Version: 2.4.0.dev20240607+cu124
  • CPU Architecture: x64
  • OS: Windows 11
  • How you installed PyTorch: pip
  • Python version: 3.12.3
  • CUDA version: 12.4
  • GPU models and configuration: RTX 3050

Additional context

Adding a use_default=True argument to the to(np.dtype) call in create_constant (torch_tensorrt/dynamo/conversion/converter_utils.py):

    value, _enums.dtype._from(dtype).to(np.dtype) if dtype is not None else None

makes the compilation succeed. But I'm not sure whether you'd prefer to solve it another way.
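For concreteness, a minimal before/after sketch of that change, assuming to() accepts the use_default keyword as described in the comments below (the np_dtype name is illustrative, not from the source):

```python
import numpy as np
import torch
from torch_tensorrt import _enums

dtype = torch.bfloat16  # stand-in for create_constant's dtype argument

# Before: raises TypeError("Unspported numpy dtype"),
# because numpy has no native bfloat16 dtype.
# np_dtype = _enums.dtype._from(dtype).to(np.dtype)

# Proposed: with use_default=True, unsupported dtypes fall back to a
# numpy-representable default instead of raising.
np_dtype = _enums.dtype._from(dtype).to(np.dtype, use_default=True)
```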

HolyWu added the bug label on Jun 9, 2024
@narendasan (Collaborator) commented:

This is likely the solution, as there is no native bf16 type in numpy; casting to float would be the best option.
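To illustrate the numpy limitation behind this, a minimal standalone snippet (an illustration, not from the thread):

```python
import torch

t = torch.randn(4, dtype=torch.bfloat16)

# Direct conversion fails, since numpy has no native bfloat16 dtype:
#   t.numpy()  ->  TypeError: Got unsupported ScalarType BFloat16

# Casting to float first succeeds -- the fallback suggested above:
arr = t.to(torch.float32).numpy()
print(arr.dtype)  # float32
```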

@umarbutler commented:

+1 I just got this error on torch-tensorrt 2.4.0.

@umarbutler commented Sep 25, 2024

@narendasan The same issue just arose from convert_binary_elementwise in torch_tensorrt\dynamo\conversion\impl\elementwise\base.py, line 129: rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype)).

The solution was again to add use_default=True to the to(np.dtype) call, for both this rhs_val line and the lhs_val line that occurs right under it; see the sketch below.
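A sketch of that two-line change inside convert_binary_elementwise (library-internal code paraphrased from the comment above, so not standalone; only the use_default=True addition is new):

```python
# Before -- raises TypeError("Unspported numpy dtype") when lhs_dtype is bf16:
rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype))

# After -- fall back to a numpy-supported dtype instead of raising;
# the symmetric lhs_val assignment directly below gets the same change:
rhs_val = np.array([rhs_val], dtype=_enums.dtype._from(lhs_dtype).to(np.dtype, use_default=True))
```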

@dgcnz (Contributor) commented Oct 16, 2024

This fix also worked for me (torch_tensorrt==2.6.0.dev20241013+cu124), although it wasn't really worth it: leaving the model at full fp32 precision and then compiling with enabled_precisions = {fp32, bf16} yielded a faster model anyway.

| model's precision | trt.enabled_precisions | latency |
|-------------------|------------------------|---------|
| fp32              | fp32+bf16              | 17.261  |
| bf16              | fp32+bf16              | 22.913  |
| bf16              | bf16                   | 22.938  |
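A sketch of that faster configuration, adapting the repro script from above (assumption: same module and input shape; the model is simply left in fp32):

```python
# Keep the model in fp32; enable bf16 only as a TensorRT kernel precision.
model = MyModule().eval().to(device)  # no .bfloat16() cast
inputs = [torch.randn((128, 20), device=device)]  # fp32 inputs

with torch.inference_mode():
    optimized_model = torch_tensorrt.compile(
        model,
        ir="torch_compile",
        inputs=inputs,
        enabled_precisions={torch.float32, torch.bfloat16},
        min_block_size=1,
        device=device,
    )
    optimized_model(*inputs)
```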
