feat: dynamic shape support for pad ops (#3045)

chohk88 authored Aug 13, 2024
1 parent 6321710 commit f32c7a9
Showing 3 changed files with 504 additions and 83 deletions.
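
For context, a minimal sketch of the user-facing workflow this commit enables. torch_tensorrt.Input and torch_tensorrt.compile are the library's public API, but the module, shape ranges, and option values below are illustrative, not taken from this PR:

import torch
import torch_tensorrt  # assumes a build that includes this commit

class PadModule(torch.nn.Module):
    def forward(self, x):
        # constant-mode F.pad lowers to aten.constant_pad_nd, one of the
        # converters marked supports_dynamic_shapes=True in this commit.
        return torch.nn.functional.pad(x, (1, 1, 2, 2), mode="constant", value=0.0)

model = PadModule().eval().cuda()

# Dynamic batch and spatial dimensions via a min/opt/max shape range.
inputs = [
    torch_tensorrt.Input(
        min_shape=(1, 3, 32, 32),
        opt_shape=(4, 3, 64, 64),
        max_shape=(8, 3, 128, 128),
        dtype=torch.float32,
    )
]
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)
print(trt_module(torch.randn(2, 3, 48, 48).cuda()).shape)  # torch.Size([2, 3, 52, 50])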
34 changes: 25 additions & 9 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2943,7 +2943,9 @@ def aten_ops_addmm(
)


@dynamo_tensorrt_converter(torch.ops.aten.constant_pad_nd.default)
@dynamo_tensorrt_converter(
torch.ops.aten.constant_pad_nd.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2967,9 +2969,15 @@ def aten_ops_constant_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.reflection_pad3d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.reflection_pad1d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.reflection_pad2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.reflection_pad3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -2992,9 +3000,15 @@ def aten_ops_reflection_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten.replication_pad1d.default)
@dynamo_tensorrt_converter(torch.ops.aten.replication_pad2d.default)
@dynamo_tensorrt_converter(torch.ops.aten.replication_pad3d.default)
@dynamo_tensorrt_converter(
torch.ops.aten.replication_pad1d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.replication_pad2d.default, supports_dynamic_shapes=True
)
@dynamo_tensorrt_converter(
torch.ops.aten.replication_pad3d.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -3017,7 +3031,9 @@ def aten_ops_replication_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten._pad_circular.default)
@dynamo_tensorrt_converter(
torch.ops.aten._pad_circular.default, supports_dynamic_shapes=True
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -3040,7 +3056,7 @@ def aten_ops_circular_pad(
)


@dynamo_tensorrt_converter(torch.ops.aten.pad.default)
@dynamo_tensorrt_converter(torch.ops.aten.pad.default, supports_dynamic_shapes=True)
@enforce_tensor_types(
{
0: (TRTTensor,),
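The registrations above cover the pad variants emitted by torch.nn.functional.pad; roughly, the mapping is as follows (standard PyTorch lowering behaviour, stated from general knowledge rather than this diff):

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)

F.pad(x, (1, 1, 2, 2), mode="constant", value=0.0)  # -> aten.constant_pad_nd
F.pad(x, (1, 1, 2, 2), mode="reflect")              # -> aten.reflection_pad2d
F.pad(x, (1, 1, 2, 2), mode="replicate")            # -> aten.replication_pad2d
F.pad(x, (1, 1, 2, 2), mode="circular")             # -> aten._pad_circular
# torch.ops.aten.pad.default is the composite form that can also appear in
# exported graphs and dispatches to one of the specialized ops above.
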
208 changes: 134 additions & 74 deletions py/torch_tensorrt/dynamo/conversion/impl/pad.py
@@ -1,15 +1,17 @@
from typing import Optional, Sequence, Union

import numpy as np
import tensorrt as trt
from torch.fx.node import Target
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor
from torch_tensorrt.fx.converters.converter_utils import (
has_dynamic_shape,
from torch_tensorrt.dynamo.conversion.converter_utils import (
get_trt_tensor,
set_layer_name,
)
from torch_tensorrt.fx.types import TRTTensor
from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
from torch_tensorrt.dynamo.types import TRTTensor

"""
Note: IPaddingLayer is deprecated in TensorRT 8.2 and will be removed in TensorRT 10.0.
@@ -18,39 +20,109 @@
"""


def constant_padNd(
def get_padded_shape_tensors(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
value: Union[int, float] = 0,
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(pad) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
f"Trying to pad last {len(pad) // 2} dimensions but the input only has {rank} dimensions."
)

input_shape_tensor = get_shape_with_dynamic_shape(
ctx,
target,
source_ir,
name + "_input_shape",
input.shape,
input,
)
padded_shape_tensor = input_shape_tensor

start_list = [0] * rank
new_shape = list(input.shape)
for i in range(len(pad) // 2):
dim_index = rank - (i + 1)
pad_before = pad[i * 2]
pad_after = pad[i * 2 + 1]

for i in range(0, len(pad) // 2):
start_list[-i - 1] = -pad[i * 2]
new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
pad_sum = get_trt_tensor(
ctx, pad_before + pad_after, f"{name}_pad_sum_{i}", dtype=np.int32
)
dim_shape = ctx.net.add_slice(
input_shape_tensor,
start=(dim_index,),
shape=(1,),
stride=(1,),
).get_output(0)

new_dim_shape = impl.elementwise.add(
ctx, target, source_ir, f"{name}_shape_dim_{i}", dim_shape, pad_sum
)
start_list[dim_index] = -pad_before

slices = []
for j in range(rank):
if j == dim_index:
slices.append(new_dim_shape)
else:
slices.append(
ctx.net.add_slice(
padded_shape_tensor,
start=(j,),
shape=(1,),
stride=(1,),
).get_output(0)
)
padded_shape_tensor = impl.cat.cat(
ctx, target, source_ir, f"{name}_cat", slices, 0
)

start_indices_tensor = get_trt_tensor(
ctx,
np.array(start_list, dtype=np.int32),
f"{name}_start_indices_tensor",
dtype=np.int32,
)

return start_indices_tensor, padded_shape_tensor
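
Editorial aside: in eager terms, the shape-tensor graph built by get_padded_shape_tensors computes the arithmetic below. This NumPy sketch assumes a fully static input shape and is only a reference for what the TRT slice/elementwise/concat layers evaluate at runtime:

import numpy as np

def padded_shape_reference(input_shape, pad):
    # For the i-th padded dimension (innermost first): slice start is
    # -pad_before and output size is dim + pad_before + pad_after;
    # untouched leading dimensions keep start 0 and their original size.
    rank = len(input_shape)
    start = [0] * rank
    shape = list(input_shape)
    for i in range(len(pad) // 2):
        dim_index = rank - (i + 1)
        start[dim_index] = -pad[i * 2]
        shape[dim_index] += pad[i * 2] + pad[i * 2 + 1]
    return np.array(start, dtype=np.int32), np.array(shape, dtype=np.int32)

# padded_shape_reference((2, 3, 4, 5), (1, 1, 2, 0)) -> ([0, 0, -2, -1], [2, 3, 6, 7])
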


def constant_padNd(
ctx: ConversionContext,
target: Union[Target, str],
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
value: Union[int, float] = 0,
) -> TRTTensor:

rank = len(input.shape)

start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
ctx, target, source_ir, name, input, pad
)

stride_list = [1] * rank
stride_tensor = get_trt_tensor(
ctx,
np.array(stride_list, dtype=np.int32),
f"{name}_stride_tensor",
dtype=np.int32,
)

layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_indices_tensor)
layer.set_input(2, padded_shape_tensor)
layer.set_input(3, stride_tensor)

value_const = get_trt_tensor(ctx, value, f"{name}_value", input.dtype)
layer.set_input(4, value_const)
layer.mode = trt.SampleMode.FILL
@@ -67,30 +139,26 @@ def reflection_padNd(
input: TRTTensor,
padding: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(padding) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(padding) // 2):
start_list[-i - 1] = -padding[i * 2]
new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
ctx, target, source_ir, name, input, padding
)

stride_list = [1] * rank
stride_tensor = get_trt_tensor(
ctx,
np.array(stride_list, dtype=np.int32),
f"{name}_stride_tensor",
dtype=np.int32,
)

layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_indices_tensor)
layer.set_input(2, padded_shape_tensor)
layer.set_input(3, stride_tensor)
layer.mode = trt.SampleMode.REFLECT

set_layer_name(layer, target, name, source_ir)
@@ -105,30 +173,26 @@ def replication_padNd(
input: TRTTensor,
padding: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(padding) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(padding) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(padding) // 2):
start_list[-i - 1] = -padding[i * 2]
new_shape[-i - 1] += padding[i * 2] + padding[i * 2 + 1]
start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
ctx, target, source_ir, name, input, padding
)

stride_list = [1] * rank
stride_tensor = get_trt_tensor(
ctx,
np.array(stride_list, dtype=np.int32),
f"{name}_stride_tensor",
dtype=np.int32,
)

layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_indices_tensor)
layer.set_input(2, padded_shape_tensor)
layer.set_input(3, stride_tensor)
layer.mode = trt.SampleMode.CLAMP

set_layer_name(layer, target, name, source_ir)
@@ -141,32 +205,28 @@ def circular_padNd(
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
pad: Sequence[int],
padding: Sequence[int],
) -> TRTTensor:
if has_dynamic_shape(input.shape):
assert input.shape[1] != -1, "Channel dim can't be dynamic for padding."

rank = len(input.shape)

if len(pad) // 2 > rank:
raise RuntimeError(
f"Trying to pad last {len(pad) // 2} dimension but the input only has {rank} dimension."
)

start_list = [0] * rank
new_shape = list(input.shape)

for i in range(0, len(pad) // 2):
start_list[-i - 1] = -pad[i * 2]
new_shape[-i - 1] += pad[i * 2] + pad[i * 2 + 1]
start_indices_tensor, padded_shape_tensor = get_padded_shape_tensors(
ctx, target, source_ir, name, input, padding
)

stride_list = [1] * rank
stride_tensor = get_trt_tensor(
ctx,
np.array(stride_list, dtype=np.int32),
f"{name}_stride_tensor",
dtype=np.int32,
)

layer = ctx.net.add_slice(
input,
start=tuple(start_list),
shape=tuple(new_shape),
stride=tuple(stride_list),
input, start=trt.Dims(), shape=trt.Dims(), stride=trt.Dims()
)
layer.set_input(1, start_indices_tensor)
layer.set_input(2, padded_shape_tensor)
layer.set_input(3, stride_tensor)
layer.mode = trt.SampleMode.WRAP

set_layer_name(layer, target, name, source_ir)
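
All four padding flavours now share one pattern: the slice layer is created with placeholder trt.Dims() arguments and its start, shape, and stride are supplied as runtime tensors through set_input slots 1-3 (slot 4 carries the fill value in the constant case), while layer.mode selects the out-of-bounds behaviour. The mode-to-semantics mapping, illustrated with NumPy's pad modes as a rough analogy (the analogy is an editorial addition, not part of the PR):

import numpy as np

x = np.arange(6, dtype=np.float32).reshape(2, 3)
pad = ((1, 1), (2, 2))  # (before, after) for each dimension

np.pad(x, pad, mode="constant", constant_values=0.0)  # ~ trt.SampleMode.FILL    (constant_padNd)
np.pad(x, pad, mode="reflect")                        # ~ trt.SampleMode.REFLECT (reflection_padNd)
np.pad(x, pad, mode="edge")                           # ~ trt.SampleMode.CLAMP   (replication_padNd)
np.pad(x, pad, mode="wrap")                           # ~ trt.SampleMode.WRAP    (circular_padNd)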