Merge pull request #2097 from andi4191/anurag.dixit/aten_unflatten
feat: Added support for aten::unflatten converter
peri044 authored Aug 2, 2023
2 parents 76800bc + a47b5fe commit 3c49608
Showing 3 changed files with 228 additions and 1 deletion.
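For context, aten::unflatten.int splits one dimension of a tensor into several, where the new sizes must multiply back to the original extent of that dimension. In the static-shape path of the converter below, the output shape is computed by simply running torch::unflatten on a dummy tensor of the input's shape. A minimal libtorch sketch of the op's semantics (the shapes here are illustrative, not taken from the tests):

#include <torch/torch.h>
#include <iostream>

int main() {
  // unflatten splits dim 1 (extent 512) into [512, 1, 1]; the new sizes
  // must multiply back to the original extent of the unflattened dim.
  auto x = torch::rand({1, 512});
  auto y = torch::unflatten(x, 1, {512, 1, 1});
  std::cout << y.sizes() << std::endl; // prints [1, 512, 1, 1]
  return 0;
}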
98 changes: 98 additions & 0 deletions core/conversion/converters/impl/shuffle.cpp
@@ -65,6 +65,104 @@ static auto shuffle_registrations TORCHTRT_UNUSED =
return true;
}})
.pattern(
{"aten::unflatten.int(Tensor self, int dim, int[] sizes) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
auto dim = args[1].unwrapToInt();
auto in_shape = util::toVec(in->getDimensions());
std::vector<int64_t> new_shape;
nvinfer1::ITensor* shape_tensor = nullptr;
if (ctx->input_is_dynamic) {
/*
 * Handle a negative dim by wrapping it around. If the negative dim is
 * still out of range for in_shape, indexing fails with an
 * index-out-of-bounds error, as expected.
 */
if (dim < 0) {
dim = in_shape.size() + dim;
}
std::cout << "Dynamic shape case" << std::endl;
LOG_DEBUG("Using dynamic version of reshape layer");
if (args[2].isITensorList()) {
std::cout << "isTensorList case" << std::endl;
LOG_DEBUG("Shape tensor is an ITensorList");
auto expand_shape = args[2].unwrapToITensorList();
auto shape_layer = ctx->net->addShape(*in);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
auto shape_1d_tensor = shape_layer->getOutput(0);

std::vector<int> before_dim_indices_vector(dim);
std::iota(before_dim_indices_vector.begin(), before_dim_indices_vector.end(), 0);

nvinfer1::ITensor* before_dim_gather_out = nullptr;
if (before_dim_indices_vector.size()) {
at::Tensor before_dim_indices = torch::tensor(before_dim_indices_vector).to(torch::kI32);
auto before_dim_indices_out = converters::tensor_to_const(ctx, before_dim_indices);
auto before_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *before_dim_indices_out, 0);
TORCHTRT_CHECK(before_dim_gather_layer, "Unable to create gather layer from node: " << *n);
before_dim_gather_out = before_dim_gather_layer->getOutput(0);
}

std::vector<int> after_dim_indices_vector(in_shape.size() - (dim + 1));
std::iota(after_dim_indices_vector.begin(), after_dim_indices_vector.end(), dim + 1);

nvinfer1::ITensor* after_dim_gather_out = nullptr;
if (after_dim_indices_vector.size()) {
at::Tensor after_dim_indices = torch::tensor(after_dim_indices_vector).to(torch::kI32);
auto after_dim_indices_out = converters::tensor_to_const(ctx, after_dim_indices);
auto after_dim_gather_layer = ctx->net->addGather(*shape_1d_tensor, *after_dim_indices_out, 0);
TORCHTRT_CHECK(after_dim_gather_layer, "Unable to create gather layer from node: " << *n);
after_dim_gather_out = after_dim_gather_layer->getOutput(0);
}

std::vector<nvinfer1::ITensor*> shape_tensors;
if (before_dim_gather_out) {
shape_tensors.push_back(before_dim_gather_out);
}
for (auto new_shape_tensor : expand_shape) {
shape_tensors.push_back(new_shape_tensor);
}
if (after_dim_gather_out) {
shape_tensors.push_back(after_dim_gather_out);
}

auto shape_cat_layer = ctx->net->addConcatenation(shape_tensors.data(), shape_tensors.size());
TORCHTRT_CHECK(shape_cat_layer, "Unable to create cat layer from node: " << *n);
shape_tensor = shape_cat_layer->getOutput(0);
LOG_DEBUG("Shape tensor shape: " << shape_tensor->getDimensions());
} else if (args[2].isIntList()) {
auto shape_vec = args[2].unwrapToIntList().vec();
// New shape
new_shape.insert(new_shape.end(), in_shape.begin(), in_shape.begin() + dim);
new_shape.insert(new_shape.end(), shape_vec.begin(), shape_vec.end());
new_shape.insert(new_shape.end(), in_shape.begin() + dim + 1, in_shape.end());

shape_tensor = tensor_to_const(ctx, torch::tensor(new_shape).to(torch::kI32));
} else {
LOG_ERROR(
"Invalid IValue type of " << args[2].IValue()->type()
<< " detected for shape tensor from node: " << *n);
return false;
}
} else {
new_shape =
torch::unflatten(torch::rand(in_shape), dim, args[2].unwrapToIntList().vec()).sizes().vec();
}
auto shuffle = ctx->net->addShuffle(*in);
TORCHTRT_CHECK(shuffle, "Unable to create shuffle layer from node: " << *n);
shuffle->setName(util::node_info(n).c_str());

if (ctx->input_is_dynamic) {
shuffle->setInput(1, *shape_tensor);
} else {
shuffle->setReshapeDimensions(util::toDims(new_shape));
}

auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle->getOutput(0));
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());

return true;
}})
.pattern(
{"aten::reshape(Tensor self, int[] shape) -> (Tensor)",
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
auto in = args[0].ITensorOrFreeze(ctx);
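The dynamic-shape branch above assembles the output shape at runtime: it gathers the input shape entries before and after dim with IGatherLayer, then concatenates the unflattened sizes in between. A hedged sketch of that shape arithmetic on plain vectors (the helper name unflatten_shape is hypothetical, not part of the converter):

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the converter's shape arithmetic:
// new shape = in_shape[0:dim] ++ sizes ++ in_shape[dim+1:].
std::vector<int64_t> unflatten_shape(
    const std::vector<int64_t>& in_shape,
    int64_t dim,
    const std::vector<int64_t>& sizes) {
  if (dim < 0) {
    dim += static_cast<int64_t>(in_shape.size()); // wrap negative dims
  }
  std::vector<int64_t> out(in_shape.begin(), in_shape.begin() + dim);
  out.insert(out.end(), sizes.begin(), sizes.end());
  out.insert(out.end(), in_shape.begin() + dim + 1, in_shape.end());
  return out;
}

// unflatten_shape({1, 512}, 1, {512, 1, 1}) == {1, 512, 1, 1}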
52 changes: 52 additions & 0 deletions tests/core/conversion/converters/test_shuffle.cpp
@@ -364,3 +364,55 @@ TEST(Converters, ATenPixelShuffle5DConvertsCorrectly) {

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenUnflattenConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : int = prim::Constant[value=512]()
%4 : int = prim::Constant[value=1]()
%5 : int = prim::Constant[value=1]()
%6 : int[] = prim::ListConstruct(%3, %4, %5)
%7 : Tensor = aten::unflatten(%x.1, %2, %6)
return (%7))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});

auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenUnflattenNegativeDimConvertsCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=-1]()
%3 : int = prim::Constant[value=512]()
%4 : int = prim::Constant[value=1]()
%5 : int = prim::Constant[value=1]()
%6 : int[] = prim::ListConstruct(%3, %4, %5)
%7 : Tensor = aten::unflatten(%x.1, %2, %6)
return (%7))IR";

auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(0, 5, {1, 512}, {at::kCUDA});
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});

auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});

in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
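
The negative-dim test relies on standard PyTorch wrap-around semantics, which the converter reproduces via dim = in_shape.size() + dim. A small sketch of the equivalence the test asserts (illustrative shapes, assuming libtorch):

#include <torch/torch.h>

int main() {
  auto x = torch::rand({1, 512});
  // On a rank-2 tensor, dim = -1 wraps to dim = 1, so both calls agree.
  auto a = torch::unflatten(x, 1, {512, 1, 1});
  auto b = torch::unflatten(x, -1, {512, 1, 1});
  return a.sizes() == b.sizes() ? 0 : 1; // both [1, 512, 1, 1]
}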
79 changes: 78 additions & 1 deletion tests/cpp/test_dynamic_size.cpp
@@ -124,4 +124,81 @@ TEST(Converters, ATenResizeGetItemDynShapeMulCorrectly) {
auto trt = trt_results[0].reshape(jit_results[0].sizes());

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
}
}

TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectly) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%2 : int = prim::Constant[value=1]()
%3 : int = aten::size(%x.1, %2)
%4 : int = prim::Constant[value=256]()
%5 : int = prim::Constant[value=2]()
%6 : int[] = prim::ListConstruct(%4, %5)
%7 : Tensor = aten::unflatten(%x.1, %2, %6)
return (%7))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(0, 10, {1, 512, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyFirstDim) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%1 : int = prim::Constant[value=0]()
%2 : int = prim::Constant[value=1]()
%3 : int = aten::size(%x.1, %1)
%6 : int[] = prim::ListConstruct(%2, %2, %3, %2, %2)
%7 : Tensor = aten::unflatten(%x.1, %1, %6)
return (%7))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(0, 10, {64, 512, 1}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

TEST(Converters, ATenUnflattenDynShapeITensorShapeCorrectlyLastDim) {
const auto graph = R"IR(
graph(%x.1 : Tensor):
%1 : int = prim::Constant[value=2]()
%2 : int = prim::Constant[value=1]()
%3 : int = aten::size(%x.1, %1)
%5 : int = prim::Constant[value=2]()
%6 : int[] = prim::ListConstruct(%3, %2, %2)
%7 : Tensor = aten::unflatten(%x.1, %5, %6)
return (%7))IR";
auto g = std::make_shared<torch::jit::Graph>();
torch::jit::parseIR(graph, g.get());

auto in = at::randint(0, 10, {1, 512, 9}, {at::kCUDA});

auto jit_in = at::clone(in);
auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in});

auto trt_in = at::clone(in);
params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true);

ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}
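
For reference, the last dynamic-shape test corresponds roughly to the following eager-mode computation, where one of the unflatten sizes comes from the input's own shape and so is only known at runtime (a hedged sketch; unflatten_last_dim is a hypothetical name):

#include <torch/torch.h>

torch::Tensor unflatten_last_dim(const torch::Tensor& x) {
  auto d = x.size(2);                       // aten::size(%x.1, 2) in the IR
  return torch::unflatten(x, 2, {d, 1, 1}); // [1, 512, 9] -> [1, 512, 9, 1, 1]
}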
