diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 852f28b074..e3aaaec175 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -3573,7 +3573,9 @@ def aten_ops_any(
     )


-@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten._pdist_forward.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
index 5c6943bf5b..c9599c7476 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py
@@ -10,15 +10,18 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     cast_trt_tensor,
+    create_constant,
     get_axes_for_reduce_op,
     get_positive_dim,
     get_trt_tensor,
-    to_numpy,
-)
-from torch_tensorrt.fx.converters.converter_utils import (
     has_dynamic_shape,
     set_layer_name,
+    to_numpy,
 )
+from torch_tensorrt.dynamo.conversion.impl.cat import cat
+from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
+from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
+from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.fx.types import TRTTensor
 from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -417,20 +420,21 @@ def pdist(
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     shape = input.shape
     # Extend input from shape [N, D] to [N, 1, D]
-    extend_input = impl.shuffle.reshape(
+    extend_input = impl.unsqueeze.unsqueeze(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape",
+        f"{name}_unsqueeze",
         input,
-        shape=shape[0:1] + (1,) + shape[1:],
+        1,
     )
+    # Expand the input from [N, 1, D] to [N, N, D]
     x = impl.slice.expand(
         ctx,
         target,
         source_ir,
-        f"{name}_sub",
+        f"{name}_expand",
         extend_input,
         (shape[0], shape[0]) + shape[1:],
     )
@@ -482,8 +486,194 @@ def pdist(
         raise RuntimeError(
             f"p should between [0, inf], currently p={p} is not supported!"
         )
-    indices = np.triu_indices(shape[0], k=1)
-    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+    if shape[0] == DYNAMIC_DIM:
+        dim = get_shape(ctx, target, source_ir, f"{name}_get_shape", input, 0)
+        # Collapse the 1-D shape value into the 0-D scalar tensor expected by tri_upper_indices
+        shuffle_layer = ctx.net.add_shuffle(dim)
+        shuffle_layer.reshape_dims = trt.Dims()
+        set_layer_name(shuffle_layer, target, f"{name}_shuffle", source_ir)
+        dim_tensor = shuffle_layer.get_output(0)
+        indices_tensor = tri_upper_indices(
+            ctx, target, source_ir, f"{name}_triu_indices", dim_tensor
+        )
+        # ND gather picks norm[i, j] for every (i, j) row of indices_tensor
+        gather_layer = ctx.net.add_gather_v2(
+            norm, indices_tensor, mode=trt.GatherMode.ND
+        )
+        set_layer_name(gather_layer, target, f"{name}_gather_layer", source_ir)
+        gather_layer.axis = 2
+        return gather_layer.get_output(0)
+    else:
+        indices = np.triu_indices(shape[0], k=1)
+        return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
+
+
+def tri_upper_indices(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    size_tensor: TRTTensor,
+) -> TRTTensor:
+    """
+    Return the indices of the upper-triangular part of a square matrix, with diagonal
+    offset 1 (the diagonal itself is excluded), as an M-by-2 tensor where
+    M = size * (size - 1) / 2. A single loop computes the indices, following the
+    pseudocode below:
+
+        x, y, y_start = 0, 0, 1
+        x_out, y_out = [], []
+        out_size = size * (size - 1) // 2
+        for _ in range(out_size):
+            y_out.append(y_start + y)
+            x_out.append(x)
+            y += 1
+            if (y_start + y) >= size:
+                x += 1
+                y_start += 1
+                y = 0
+
+    Args:
+        ctx (ConversionContext): A ConversionContext containing the TensorRT network.
+        target (Target): Target of calling node.
+        source_ir (Optional[SourceIR]): SourceIR of calling converter.
+        name (str): Name of the calling layer.
+        size_tensor (TRTTensor): 0-D tensor holding the number of rows of the square matrix.
+
+    Returns:
+        TRTTensor: the indices as an M-by-2 tensor.
+
+    Example:
+        If size_tensor is 4, the result is [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]].
+    """
+    constant_0 = create_constant(ctx, 0, f"{name}_zero", np.int32, 0)
+    constant_1 = create_constant(ctx, 1, f"{name}_one", np.int32, 0)
+    constant_2 = create_constant(ctx, 2, f"{name}_two", np.int32, 0)
+
+    size_minus_one = impl.elementwise.sub(
+        ctx, target, source_ir, f"{name}_size_minus_one", size_tensor, constant_1
+    )
+
+    size_mult_prev = impl.elementwise.mul(
+        ctx, target, source_ir, f"{name}_size_mult_prev", size_tensor, size_minus_one
+    )
+
+    # Trip count M = size * (size - 1) // 2, one iteration per upper-triangular entry
+    num_loop = impl.elementwise.floor_divide(
+        ctx, target, source_ir, f"{name}_num_loop", size_mult_prev, constant_2
+    )
+
+    loop = ctx.net.add_loop()
+    loop.add_trip_limit(num_loop, trt.TripLimit.COUNT)
+
+    # Loop state: x is the row index, y the offset within the row,
+    # y_start the first column of the current row
+    x_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(x_recurrence, target, f"{name}_x_recurrence", source_ir)
+    x_tensor = x_recurrence.get_output(0)
+
+    y_recurrence = loop.add_recurrence(constant_0)
+    set_layer_name(y_recurrence, target, f"{name}_y_recurrence", source_ir)
+    y_tensor = y_recurrence.get_output(0)
+
+    y_start_recurrence = loop.add_recurrence(constant_1)
+    set_layer_name(y_start_recurrence, target, f"{name}_y_start_recurrence", source_ir)
+    y_start_tensor = y_start_recurrence.get_output(0)
+
+    x_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_inc",
+        x_tensor,
+        constant_1,
+    )
+
+    y_current = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_current",
+        y_start_tensor,
+        y_tensor,
+    )
+
+    y_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_inc",
+        y_tensor,
+        constant_1,
+    )
+
+    next_y = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_next_y",
+        y_start_tensor,
+        y_inc,
+    )
+
+    y_start_inc = impl.elementwise.add(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_inc",
+        y_start_tensor,
+        constant_1,
+    )
+    # When the next column index would run past the row end, advance to the next row
+    cond = ge(ctx, target, source_ir, f"{name}_cond", next_y, size_tensor)
+    x_output = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_output",
+        x_inc,
+        x_tensor,
+        cond,
+    )
+    x_recurrence.set_input(1, x_output)
+
+    y_start_current = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_start_current",
+        y_start_inc,
+        y_start_tensor,
+        cond,
+    )
+    y_start_recurrence.set_input(1, y_start_current)
+
+    y_val = impl.condition.select(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_y_val",
+        constant_0,
+        y_inc,
+        cond,
+    )
+    y_recurrence.set_input(1, y_val)
+
+    loop_output_x = loop.add_loop_output(x_tensor, trt.LoopOutput.CONCATENATE)
+    loop_output_y = loop.add_loop_output(y_current, trt.LoopOutput.CONCATENATE)
+    loop_output_x.set_input(1, num_loop)
+    loop_output_y.set_input(1, num_loop)
+
+    # Concatenate the two length-M index tensors into one 2-by-M tensor:
+    # [0, 0, 0], [1, 2, 3] -> [[0, 0, 0], [1, 2, 3]]
+    x_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_x_index", loop_output_x.get_output(0), (1, -1)
+    )
+    y_index = impl.shuffle.reshape(
+        ctx, target, source_ir, f"{name}_y_index", loop_output_y.get_output(0), (1, -1)
+    )
+
+    x_y_tensor = cat(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_x_y_tensor",
+        [x_index, y_index],
+        0,
+    )
+
+    # Transpose the 2-by-M tensor to M-by-2: [[0, 0, 0], [1, 2, 3]] -> [[0, 1], [0, 2], [0, 3]]
+    indices_tensor = ctx.net.add_shuffle(x_y_tensor)
+    set_layer_name(indices_tensor, target, f"{name}_indices_tensor", source_ir)
+    indices_tensor.first_transpose = trt.Permutation([1, 0])
+    indices_tensor.reshape_dims = (-1, 2)
+
+    return indices_tensor.get_output(0)


 def cdist_forward(
diff --git a/tests/py/dynamo/conversion/test_pdist_aten.py b/tests/py/dynamo/conversion/test_pdist_aten.py
index 67e547faf2..a7843780b8 100644
--- a/tests/py/dynamo/conversion/test_pdist_aten.py
+++ b/tests/py/dynamo/conversion/test_pdist_aten.py
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input

 from .harness import DispatchTestCase

@@ -32,5 +33,62 @@ def forward(self, input):
     )


+class TestDynamicShapePdistConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            (
+                "dim0_dynamic_dim1_static_p_0",
+                (1, 4),
+                (2, 4),
+                (4, 4),
+                0,
+            ),
+            (
+                "dim0_static_dim1_dynamic_p_1",
+                (3, 1),
+                (3, 2),
+                (3, 4),
+                1,
+            ),
+            (
+                "dim0_dynamic_dim1_static_p_other",
+                (1, 5),
+                (2, 5),
+                (6, 5),
+                0.4,
+            ),
+            (
+                "dim0_dynamic_dim1_dynamic_p_inf",
+                (1, 1),
+                (2, 2),
+                (5, 4),
+                float("inf"),
+            ),
+            (
+                "dim0_dynamic_dim1_dynamic_p_other",
+                (2, 1),
+                (3, 2),
+                (4, 7),
+                1.7,
+            ),
+        ]
+    )
+    def test_pdist_float(self, _, min_shape, opt_shape, max_shape, p):
+        class Pdist(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten._pdist_forward.default(input, p)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=torch.float,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(Pdist(), input_specs)
+
+
 if __name__ == "__main__":
     run_tests()
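
As a sanity check for the loop logic in tri_upper_indices, here is a minimal pure-Python sketch, not part of the patch, that mirrors the docstring's pseudocode and verifies it against np.triu_indices. The helper name tri_upper_indices_ref is made up for illustration, and the final gather line only emulates what add_gather_v2 with trt.GatherMode.ND does on [N, N] data:

import numpy as np


def tri_upper_indices_ref(size: int) -> np.ndarray:
    # Plain-Python version of the index walk performed by the TensorRT loop:
    # x is the row index, y the offset within the row, y_start the first
    # column of the current row.
    x, y, y_start = 0, 0, 1
    x_out, y_out = [], []
    for _ in range(size * (size - 1) // 2):
        y_out.append(y_start + y)
        x_out.append(x)
        y += 1
        if (y_start + y) >= size:  # row exhausted -> advance to the next row
            x += 1
            y_start += 1
            y = 0
    return np.stack([x_out, y_out], axis=1)  # M-by-2, like the converter output


if __name__ == "__main__":
    for n in range(2, 8):
        expected = np.stack(np.triu_indices(n, k=1), axis=1)
        assert (tri_upper_indices_ref(n) == expected).all()
    # Emulated ND-gather: each (i, j) row of the indices picks matrix[i, j]
    matrix = np.arange(16).reshape(4, 4)
    idx = tri_upper_indices_ref(4)
    print(matrix[idx[:, 0], idx[:, 1]])  # [ 1  2  3  6  7 11]

Because the loop emits (x, y) pairs in the same row-major order that np.triu_indices produces, the dynamic-shape branch of pdist selects exactly the same elements as the static np.triu_indices branch.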