chore: dynamic shape support for pdist ops (#3068)
keehyuna authored Aug 8, 2024
1 parent 19f671d · commit 8ecc809
Showing 3 changed files with 260 additions and 10 deletions.
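For context: aten._pdist_forward takes an [N, D] input and returns the N * (N - 1) / 2 pairwise row distances, i.e. the upper triangle of the N x N distance matrix, so the output shape itself depends on the (possibly dynamic) batch dimension N. A minimal PyTorch sketch of that contract (illustration only, not part of this diff):

    import torch

    x = torch.randn(4, 3)                      # N = 4 rows, D = 3 features
    out = torch.nn.functional.pdist(x, p=2.0)  # lowers to aten._pdist_forward
    assert out.shape == (4 * 3 // 2,)          # N * (N - 1) // 2 = 6 distances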
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (3 additions, 1 deletion)
@@ -3573,7 +3573,9 @@ def aten_ops_any(
    )


@dynamo_tensorrt_converter(torch.ops.aten._pdist_forward.default)
@dynamo_tensorrt_converter(
    torch.ops.aten._pdist_forward.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
    {
        0: (TRTTensor,),
py/torch_tensorrt/dynamo/conversion/impl/normalization/ops.py (199 additions, 9 deletions)
@@ -10,15 +10,18 @@
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    cast_trt_tensor,
    create_constant,
    get_axes_for_reduce_op,
    get_positive_dim,
    get_trt_tensor,
    to_numpy,
)
from torch_tensorrt.fx.converters.converter_utils import (
    has_dynamic_shape,
    set_layer_name,
    to_numpy,
)
from torch_tensorrt.dynamo.conversion.impl.cat import cat
from torch_tensorrt.dynamo.conversion.impl.elementwise.ops import ge
from torch_tensorrt.dynamo.conversion.impl.shape import shape as get_shape
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.fx.utils import get_dynamic_dims

@@ -417,20 +420,21 @@ def pdist(
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    shape = input.shape
    # Extend input from shape [N, D] to [N, 1, D]
    extend_input = impl.shuffle.reshape(
    extend_input = impl.unsqueeze.unsqueeze(
        ctx,
        target,
        source_ir,
        f"{name}_reshape",
        f"{name}_unsqueeze",
        input,
        shape=shape[0:1] + (1,) + shape[1:],
        1,
    )

    # Expand the input from [N, 1, D] to [N, N, D]
    x = impl.slice.expand(
        ctx,
        target,
        source_ir,
        f"{name}_sub",
        f"{name}_expand",
        extend_input,
        (shape[0], shape[0]) + shape[1:],
    )
@@ -482,8 +486,194 @@ def pdist(
        raise RuntimeError(
            f"p should be in the range [0, inf], currently p={p} is not supported!"
        )
    indices = np.triu_indices(shape[0], k=1)
    return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)
    if shape[0] == DYNAMIC_DIM:
        dim = get_shape(ctx, target, source_ir, f"{name}_get_shape", input, 0)
        shuffle_layer = ctx.net.add_shuffle(dim)
        shuffle_layer.reshape_dims = trt.Dims()
        set_layer_name(shuffle_layer, target, f"{name}_shuffle", source_ir)
        dim_tensor = shuffle_layer.get_output(0)
        indices_tensor = tri_upper_indices(
            ctx, target, source_ir, f"{name}_triu_indices", dim_tensor
        )
        gather_layer = ctx.net.add_gather_v2(
            norm, indices_tensor, mode=trt.GatherMode.ND
        )
        set_layer_name(gather_layer, target, f"{name}_gather_layer", source_ir)
        gather_layer.axis = 2
        return gather_layer.get_output(0)
    else:
        indices = np.triu_indices(shape[0], k=1)
        return impl.select.index(ctx, target, source_ir, f"{name}_index", norm, indices)

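As an aside (illustration only, not part of this diff): for a concrete N, the ND-mode gather in the dynamic branch above selects exactly the upper-triangle entries that the static branch picks with np.triu_indices. A minimal NumPy sketch of that equivalence:

    import numpy as np

    N = 4
    norm = np.arange(N * N, dtype=np.float32).reshape(N, N)  # stand-in for the [N, N] `norm` tensor
    indices = np.stack(np.triu_indices(N, k=1), axis=-1)     # [N * (N - 1) // 2, 2] index pairs
    gathered = norm[indices[:, 0], indices[:, 1]]            # what the ND-mode gather selects here
    assert gathered.shape == (N * (N - 1) // 2,)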

def tri_upper_indices(
    ctx: ConversionContext,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    size_tensor: TRTTensor,
) -> TRTTensor:
    """
    Return the indices of the upper-triangular part of a square matrix as an N-by-2 tensor,
    using a diagonal offset of 1. A single loop computes the indices as sketched below:
        x = 0, y = 0, y_start = 1
        out_size = size * (size - 1) // 2
        for _ in range(out_size):
            y_out.append(y_start + y)
            x_out.append(x)
            y += 1
            if (y_start + y) >= size:
                x += 1
                y_start += 1
                y = 0
    Args:
        ctx (ConversionContext): A ConversionContext containing the TensorRT network.
        target (Target): Target of calling node.
        source_ir (Optional[SourceIR]): SourceIR of calling converter.
        name (str): Name of the calling layer.
        size_tensor (TRTTensor): A 0-D (scalar) tensor holding the number of rows of the square matrix.
    Example:
        If size_tensor is 4, the result is [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]].
    """
    constant_0 = create_constant(ctx, 0, f"{name}_zero", np.int32, 0)
    constant_1 = create_constant(ctx, 1, f"{name}_one", np.int32, 0)
    constant_2 = create_constant(ctx, 2, f"{name}_two", np.int32, 0)

    size_minus_one = impl.elementwise.sub(
        ctx, target, source_ir, f"{name}_size_minus_one", size_tensor, constant_1
    )

    size_mult_prev = impl.elementwise.mul(
        ctx, target, source_ir, f"{name}_size_mult_prev", size_tensor, size_minus_one
    )

    num_loop = impl.elementwise.floor_divide(
        ctx, target, source_ir, f"{name}_num_loop", size_mult_prev, constant_2
    )

    loop = ctx.net.add_loop()
    loop.add_trip_limit(num_loop, trt.TripLimit.COUNT)

    x_recurrence = loop.add_recurrence(constant_0)
    set_layer_name(x_recurrence, target, f"{name}_x_recurrence", source_ir)
    x_tensor = x_recurrence.get_output(0)

    y_recurrence = loop.add_recurrence(constant_0)
    set_layer_name(y_recurrence, target, f"{name}_y_recurrence", source_ir)
    y_tensor = y_recurrence.get_output(0)

    y_start_recurrence = loop.add_recurrence(constant_1)
    set_layer_name(y_start_recurrence, target, f"{name}_y_start_recurrence", source_ir)
    y_start_tensor = y_start_recurrence.get_output(0)

    x_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_x_inc",
        x_tensor,
        constant_1,
    )

    y_current = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_current",
        y_start_tensor,
        y_tensor,
    )

    y_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_inc",
        y_tensor,
        constant_1,
    )

    next_y = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_next_y",
        y_start_tensor,
        y_inc,
    )

    y_start_inc = impl.elementwise.add(
        ctx,
        target,
        source_ir,
        f"{name}_y_start_inc",
        y_start_tensor,
        constant_1,
    )
    cond = ge(ctx, target, source_ir, f"{name}_cond", next_y, size_tensor)
    x_output = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_x_output",
        x_inc,
        x_tensor,
        cond,
    )
    x_recurrence.set_input(1, x_output)

    y_start_current = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_y_start_current",
        y_start_inc,
        y_start_tensor,
        cond,
    )
    y_start_recurrence.set_input(1, y_start_current)

    y_val = impl.condition.select(
        ctx,
        target,
        source_ir,
        f"{name}_y_val",
        constant_0,
        y_inc,
        cond,
    )
    y_recurrence.set_input(1, y_val)

    loop_output_x = loop.add_loop_output(x_tensor, trt.LoopOutput.CONCATENATE)
    loop_output_y = loop.add_loop_output(y_current, trt.LoopOutput.CONCATENATE)
    loop_output_x.set_input(1, num_loop)
    loop_output_y.set_input(1, num_loop)

    # Cat two N tensors into 2 x N. [0, 0, 0], [1, 2, 3] -> [[0, 0, 0], [1, 2, 3]]
    x_index = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_x_index", loop_output_x.get_output(0), (1, -1)
    )
    y_index = impl.shuffle.reshape(
        ctx, target, source_ir, f"{name}_y_index", loop_output_y.get_output(0), (1, -1)
    )

    x_y_tensor = cat(
        ctx,
        target,
        source_ir,
        f"{name}_x_y_tensor",
        [x_index, y_index],
        0,
    )

    # Reshape 2 x N output to N x 2. [[0, 0, 0], [1, 2, 3]] -> [[0, 1], [0, 2], [0, 3]]
    indices_tensor = ctx.net.add_shuffle(x_y_tensor)
    set_layer_name(indices_tensor, target, f"{name}_indices_tensor", source_ir)
    indices_tensor.first_transpose = trt.Permutation([1, 0])
    indices_tensor.reshape_dims = (-1, 2)

    return indices_tensor.get_output(0)
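As an aside (illustration only, not part of this diff), the TensorRT loop above mirrors the docstring's pseudocode; a pure-Python sketch makes it easy to sanity-check the emitted index pairs against np.triu_indices:

    import numpy as np

    def tri_upper_indices_ref(size: int) -> list:
        # Pure-Python mirror of the loop that tri_upper_indices builds in TensorRT.
        x, y, y_start = 0, 0, 1
        out = []
        for _ in range(size * (size - 1) // 2):
            out.append([x, y_start + y])
            y += 1
            if (y_start + y) >= size:
                x += 1
                y_start += 1
                y = 0
        return out

    assert tri_upper_indices_ref(4) == [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]]
    assert tri_upper_indices_ref(5) == np.stack(np.triu_indices(5, k=1), axis=-1).tolist()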


def cdist_forward(
tests/py/dynamo/conversion/test_pdist_aten.py (58 additions, 0 deletions)
@@ -2,6 +2,7 @@
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests
from torch_tensorrt import Input

from .harness import DispatchTestCase

@@ -32,5 +33,62 @@ def forward(self, input):
        )


class TestDynamicShapePdistConverter(DispatchTestCase):
    @parameterized.expand(
        [
            (
                "dim0_dynamic_dim1_static_p_0",
                (1, 4),
                (2, 4),
                (4, 4),
                0,
            ),
            (
                "dim0_static_dim1_dynamic_p_1",
                (3, 1),
                (3, 2),
                (3, 4),
                1,
            ),
            (
                "dim0_dynamic_dim1_static_p_other",
                (1, 5),
                (2, 5),
                (6, 5),
                0.4,
            ),
            (
                "dim0_dynamic_dim1_dynamic_p_inf",
                (1, 1),
                (2, 2),
                (5, 4),
                float("inf"),
            ),
            (
                "dim0_dynamic_dim1_dynamic_p_other",
                (2, 1),
                (3, 2),
                (4, 7),
                1.7,
            ),
        ]
    )
    def test_pdist_float(self, _, min_shape, opt_shape, max_shape, p):
        class Pdist(nn.Module):
            def forward(self, input):
                return torch.ops.aten._pdist_forward.default(input, p)

        input_specs = [
            Input(
                min_shape=min_shape,
                opt_shape=opt_shape,
                max_shape=max_shape,
                dtype=torch.float,
            ),
        ]

        self.run_test_with_dynamic_shape(Pdist(), input_specs)


if __name__ == "__main__":
    run_tests()
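As a usage sketch (not part of this diff; the shapes are arbitrary and the call follows the public torch_tensorrt API), a module that hits this converter can be compiled with a dynamic first dimension like so:

    import torch
    import torch_tensorrt

    class Pdist(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten._pdist_forward.default(x, 2.0)

    inputs = [
        torch_tensorrt.Input(
            min_shape=(1, 4), opt_shape=(2, 4), max_shape=(4, 4), dtype=torch.float
        )
    ]
    trt_module = torch_tensorrt.compile(Pdist().eval().cuda(), ir="dynamo", inputs=inputs)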
