feat: dynamic support for pixel_shuffle and pixel_unshuffle (#3044)
chohk88 authored Aug 7, 2024
1 parent 655ed6b commit 648772c
Showing 4 changed files with 245 additions and 34 deletions.
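For context, a minimal usage sketch of what this change enables: compiling a module that calls pixel_shuffle over a range of input shapes rather than one static shape. The module, shape ranges, and values below are illustrative assumptions, not part of the commit.

import torch
import torch_tensorrt


class Upsample(torch.nn.Module):
    def forward(self, x):
        return torch.ops.aten.pixel_shuffle.default(x, 2)


model = Upsample().eval().cuda()
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[
        torch_tensorrt.Input(
            min_shape=(1, 4, 8, 8),  # channels must be divisible by upscale_factor**2
            opt_shape=(4, 4, 16, 16),
            max_shape=(8, 4, 32, 32),
            dtype=torch.float,
        )
    ],
)
out = trt_model(torch.randn(2, 4, 16, 16, device="cuda"))  # any shape in the range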
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2801,7 +2801,9 @@ def aten_ops_reshape(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.pixel_shuffle.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.pixel_shuffle.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
@@ -2824,7 +2826,9 @@ def aten_ops_pixel_shuffle(
     )
 
 
-@dynamo_tensorrt_converter(torch.ops.aten.pixel_unshuffle.default)
+@dynamo_tensorrt_converter(
+    torch.ops.aten.pixel_unshuffle.default, supports_dynamic_shapes=True
+)
 @enforce_tensor_types(
     {
         0: (TRTTensor,),
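Note on the flag: as I understand the dynamo converter registry, supports_dynamic_shapes=True marks a converter as validated for inputs with dynamic dimensions; without it, nodes whose inputs have dynamic shapes are not claimed by the TRT converter and instead fall back to PyTorch execution via graph partitioning.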
205 changes: 173 additions & 32 deletions py/torch_tensorrt/dynamo/conversion/impl/shuffle.py
@@ -1,6 +1,7 @@
 from typing import Optional, Sequence, Union
 
 import numpy as np
+import tensorrt as trt
 import torch_tensorrt.dynamo.conversion.impl as impl
 from torch.fx.node import Target
 from torch_tensorrt import _enums
@@ -12,10 +13,9 @@
     get_trt_tensor,
     set_layer_name,
 )
+from torch_tensorrt.dynamo.conversion.impl.shape import get_shape_with_dynamic_shape
 from torch_tensorrt.fx.types import TRTTensor
 
-import tensorrt as trt
-
 
 def reshape(
     ctx: ConversionContext,
@@ -61,35 +61,106 @@ def pixel_shuffle(
     input: TRTTensor,
     upscale_factor: int,
 ) -> TRTTensor:
-    shape = input.shape
-    in_channels, in_height, in_width = shape[-3:]
-    out_channels = in_channels // (upscale_factor**2)
-    out_height = in_height * upscale_factor
-    out_width = in_width * upscale_factor
-    new_shape = shape[:-3] + (
-        out_channels,
-        upscale_factor,
-        upscale_factor,
-        in_height,
-        in_width,
-    )
+    # Get input shape tensor
+    input_shape_tensor = get_shape_with_dynamic_shape(
+        ctx,
+        target,
+        source_ir,
+        name + "_shape",
+        input.shape,
+        input,
+    )
+
+    # Extract in_channels, in_height, and in_width from the input shape tensor
+    in_channels_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 3,), shape=(1,), stride=(1,)
+    ).get_output(0)
+    in_height_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 2,), shape=(1,), stride=(1,)
+    ).get_output(0)
+    in_width_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 1,), shape=(1,), stride=(1,)
+    ).get_output(0)
+
+    # Calculate out_channels, out_height, and out_width as tensors
+    upscale_factor_sq = upscale_factor * upscale_factor
+    upscale_factor_tensor = get_trt_tensor(
+        ctx, upscale_factor, f"{name}_upscale_factor"
+    )
+    upscale_factor_sq_tensor = get_trt_tensor(
+        ctx, upscale_factor_sq, f"{name}_upscale_factor_sq"
+    )
+
+    out_channels_tensor = impl.elementwise.floor_divide(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_channels_tensor",
+        in_channels_tensor,
+        upscale_factor_sq_tensor,
+    )
+    out_height_tensor = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_height_tensor",
+        in_height_tensor,
+        upscale_factor,
+    )
+    out_width_tensor = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_width_tensor",
+        in_width_tensor,
+        upscale_factor,
+    )
+
+    # Construct new shape tensor
+    new_shape_tensors = [
+        ctx.net.add_slice(
+            input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
+        ).get_output(0)
+        for i in range(len(input.shape) - 3)
+    ]
+    new_shape_tensors += [
+        out_channels_tensor,
+        upscale_factor_tensor,
+        upscale_factor_tensor,
+        in_height_tensor,
+        in_width_tensor,
+    ]
+
+    # Reshape tensor
     reshaped_tensor = reshape(
-        ctx, target, source_ir, f"{name}_reshape1", input, new_shape
+        ctx, target, source_ir, f"{name}_reshape", input, new_shape_tensors
     )
-    rank = len(shape)
+
+    # Permute shape
+    rank = len(input.shape)
     permute_shape = list(range(rank))
     permute_shape.insert(-2, rank)
     permute_shape.insert(-1, rank + 1)
     permuted_tensor = impl.permutation.permute(
         ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
     )
+
+    # Construct output shape tensor
+    out_shape_tensors = [
+        ctx.net.add_slice(
+            input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
+        ).get_output(0)
+        for i in range(len(input.shape) - 3)
+    ]
+    out_shape_tensors += [out_channels_tensor, out_height_tensor, out_width_tensor]
+
     return reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape2",
+        f"{name}_reshape_out",
         permuted_tensor,
-        shape[:-3] + (out_channels, out_height, out_width),
+        out_shape_tensors,
     )
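The converter keeps the classic reshape -> permute -> reshape decomposition of pixel_shuffle; what changes is that every extent is now computed as a TRT shape tensor (slices of the input shape, elementwise mul/floor_divide) rather than as Python ints baked in at build time. A quick eager-mode mirror of the same decomposition, for intuition only (plain torch in place of TRT layers; the names are mine, not the commit's):

import torch

def pixel_shuffle_ref(x: torch.Tensor, r: int) -> torch.Tensor:
    *lead, c, h, w = x.shape
    # split channels into (c // r^2, r, r), keeping h and w at the end
    x = x.reshape(*lead, c // (r * r), r, r, h, w)
    rank = x.dim()
    # interleave the two r-factors with h and w: (..., c', h, r, w, r)
    perm = list(range(rank - 5)) + [rank - 5, rank - 2, rank - 4, rank - 1, rank - 3]
    x = x.permute(*perm)
    return x.reshape(*lead, c // (r * r), h * r, w * r)

x = torch.randn(2, 8, 3, 3)
assert torch.equal(pixel_shuffle_ref(x, 2), torch.nn.functional.pixel_shuffle(x, 2))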


@@ -101,39 +172,109 @@ def pixel_unshuffle(
     input: TRTTensor,
     downscale_factor: int,
 ) -> TRTTensor:
-    shape = input.shape
-    in_channels, in_height, in_width = shape[-3:]
-    out_channels = in_channels * (downscale_factor**2)
-    out_height = in_height // downscale_factor
-    out_width = in_width // downscale_factor
-    new_shape = shape[:-3] + (
-        in_channels,
-        out_height,
-        downscale_factor,
-        out_width,
-        downscale_factor,
-    )
+    # Get input shape tensor
+    input_shape_tensor = get_shape_with_dynamic_shape(
+        ctx,
+        target,
+        source_ir,
+        name + "_shape",
+        input.shape,
+        input,
+    )
+
+    # Extract in_channels, in_height, and in_width from the input shape tensor
+    in_channels_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 3,), shape=(1,), stride=(1,)
+    ).get_output(0)
+    in_height_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 2,), shape=(1,), stride=(1,)
+    ).get_output(0)
+    in_width_tensor = ctx.net.add_slice(
+        input_shape_tensor, start=(len(input.shape) - 1,), shape=(1,), stride=(1,)
+    ).get_output(0)
+
+    # Calculate out_channels, out_height, and out_width as tensors
+    downscale_factor_sq = downscale_factor * downscale_factor
+    downscale_factor_tensor = get_trt_tensor(
+        ctx, downscale_factor, f"{name}_downscale_factor"
+    )
+    downscale_factor_sq_tensor = get_trt_tensor(
+        ctx, downscale_factor_sq, f"{name}_downscale_factor_sq"
+    )
+
+    out_channels_tensor = impl.elementwise.mul(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_channels_tensor",
+        in_channels_tensor,
+        downscale_factor_sq_tensor,
+    )
+    out_height_tensor = impl.elementwise.floor_divide(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_height_tensor",
+        in_height_tensor,
+        downscale_factor_tensor,
+    )
+    out_width_tensor = impl.elementwise.floor_divide(
+        ctx,
+        target,
+        source_ir,
+        f"{name}_out_width_tensor",
+        in_width_tensor,
+        downscale_factor_tensor,
+    )
+
+    # Construct new shape tensor
+    new_shape_tensors = [
+        ctx.net.add_slice(
+            input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
+        ).get_output(0)
+        for i in range(len(input.shape) - 3)
+    ]
+    new_shape_tensors += [
+        in_channels_tensor,
+        out_height_tensor,
+        downscale_factor_tensor,
+        out_width_tensor,
+        downscale_factor_tensor,
+    ]
+
     reshaped_tensor = reshape(
-        ctx, target, source_ir, f"{name}_reshape1", input, new_shape
+        ctx, target, source_ir, f"{name}_reshape", input, new_shape_tensors
     )
-    rank = len(new_shape)
-    permute_shape = tuple(range(rank - 5)) + (
+
+    # Permute shape
+    rank = len(new_shape_tensors)
+    permute_shape = list(range(rank - 5)) + [
         rank - 5,  # in_channels
         rank - 3,  # downscale_factor
         rank - 1,  # downscale_factor
         rank - 4,  # out_height
         rank - 2,  # out_width
-    )
+    ]
     permuted_tensor = impl.permutation.permute(
         ctx, target, source_ir, f"{name}_permute", reshaped_tensor, permute_shape
     )
+
+    # Construct output shape tensor
+    out_shape_tensors = [
+        ctx.net.add_slice(
+            input_shape_tensor, start=(i,), shape=(1,), stride=(1,)
+        ).get_output(0)
+        for i in range(len(input.shape) - 3)
+    ]
+    out_shape_tensors += [out_channels_tensor, out_height_tensor, out_width_tensor]
+
     return reshape(
         ctx,
         target,
         source_ir,
-        f"{name}_reshape2",
+        f"{name}_reshape_out",
         permuted_tensor,
-        shape[:-3] + (out_channels, out_height, out_width),
+        out_shape_tensors,
     )
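The unshuffle path is the exact inverse: split h and w by the factor, move the two factor axes next to the channel axis, then fold them into channels. The same kind of eager-mode mirror (again illustrative only, with assumed names):

import torch

def pixel_unshuffle_ref(x: torch.Tensor, d: int) -> torch.Tensor:
    *lead, c, h, w = x.shape
    # split h and w into (h // d, d) and (w // d, d)
    x = x.reshape(*lead, c, h // d, d, w // d, d)
    rank = x.dim()
    # (..., c, h', d1, w', d2) -> (..., c, d1, d2, h', w'), matching the converter's permute
    perm = list(range(rank - 5)) + [rank - 5, rank - 3, rank - 1, rank - 4, rank - 2]
    x = x.permute(*perm)
    return x.reshape(*lead, c * d * d, h // d, w // d)

x = torch.randn(2, 3, 6, 6)
assert torch.equal(pixel_unshuffle_ref(x, 2), torch.nn.functional.pixel_unshuffle(x, 2))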


33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_shuffle_aten.py
@@ -1,6 +1,7 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase

@@ -26,6 +27,38 @@ def forward(self, x):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                1,
+            ),
+        ]
+    )
+    def test_dynamic_shape_pixel_shuffle(
+        self, min_shape, opt_shape, max_shape, type, upscale_factor
+    ):
+        class PixelShuffle(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ops.aten.pixel_shuffle.default(x, upscale_factor)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(PixelShuffle(), input_specs)
+
 
 if __name__ == "__main__":
     run_tests()
33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_pixel_unshuffle_aten.py
@@ -1,6 +1,7 @@
 import torch
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt import Input
 
 from .harness import DispatchTestCase

@@ -24,6 +25,38 @@ def forward(self, x):
             inputs,
         )
 
+    @parameterized.expand(
+        [
+            (
+                (1, 1, 1),
+                (2, 2, 2),
+                (3, 3, 3),
+                torch.float,
+                1,
+            ),
+        ]
+    )
+    def test_dynamic_shape_pixel_unshuffle(
+        self, min_shape, opt_shape, max_shape, type, upscale_factor
+    ):
+        class PixelUnshuffle(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ops.aten.pixel_unshuffle.default(x, upscale_factor)
+
+        input_specs = [
+            Input(
+                min_shape=min_shape,
+                opt_shape=opt_shape,
+                max_shape=max_shape,
+                dtype=type,
+            ),
+        ]
+
+        self.run_test_with_dynamic_shape(PixelUnshuffle(), input_specs)
+
 
 if __name__ == "__main__":
     run_tests()
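One observation on the new tests: the dynamic-shape cases use rank-3 (C, H, W) inputs ranging from (1, 1, 1) to (3, 3, 3) with a factor of 1, so they exercise the dynamic-shape plumbing (shape slicing and the runtime mul/floor_divide) rather than the factor arithmetic itself, which the preceding static tests in each file already cover.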
