Fix: Layer norm Torchscript converter (#3062)
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
Signed-off-by: Naren Dasan <naren@nvidia.com>
Co-authored-by: Naren Dasan <naren@xnarendasan.com>
narendasan and Naren Dasan authored Aug 15, 2024
1 parent 2589fdb commit 9bbba4f
Showing 11 changed files with 161 additions and 261 deletions.
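
For context, a minimal LibTorch sketch of the aten::layer_norm pattern this converter handles (shapes here are illustrative, not taken from the PR). gamma and beta are typically rank-1 tensors over the normalized dimensions, so the converter must broadcast them to the input's full shape before building the TensorRT normalization layer; the shared add_expand helper added below is what performs that broadcast.

#include <iostream>
#include <torch/torch.h>

int main() {
  // Illustrative shapes: a [batch, seq, hidden] activation normalized over the last dimension.
  auto x = torch::randn({2, 16, 768});
  auto gamma = torch::ones({768});  // rank-1 scale; must be broadcast to x's full shape
  auto beta = torch::zeros({768});  // rank-1 shift; same broadcast requirement
  auto y = torch::layer_norm(x, /*normalized_shape=*/{768}, gamma, beta, /*eps=*/1e-5);
  std::cout << y.sizes() << std::endl;  // [2, 16, 768]
  return 0;
}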
71 changes: 71 additions & 0 deletions core/conversion/converters/converter_util.cpp
@@ -438,6 +438,77 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) {
return out;
}

nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
auto input_dims = in->getDimensions();
TORCHTRT_CHECK(
input_dims.nbDims <= expandedDims.nbDims,
"Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");

// Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
int64_t offset = expandedDims.nbDims - 1 - i;
int64_t dim = input_dims.nbDims - 1 - offset;
int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
int64_t targetSize = expandedDims.d[i];
// In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
if (targetSize != -1) {
if (size != targetSize) {
if (size != 1) {
TORCHTRT_THROW_ERROR(
"The expanded size of tensor (" << targetSize << ")"
<< " must match the existing size (" << size << ")"
<< " at dimension " << i);
}
}
} else {
// For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
// not [-1, 3, 4].
if (dim < 0) {
TORCHTRT_THROW_ERROR(
"The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
<< i);
} else {
// in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
expandedDims.d[i] = input_dims.d[dim];
}
}
}

auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = expandedDims.nbDims;
for (int64_t i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int64_t i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
auto reshape_layer = ctx->net->addShuffle(*in);
reshape_layer->setReshapeDimensions(reshape_dims);
in = reshape_layer->getOutput(0);
LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
}

// Start the slicing from beginning of tensor since this is an expand layer
std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

// Set the stride of non singleton dimension to 1
std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
for (int64_t i = 0; i < expandedDims.nbDims; i++) {
strides_vec[i] = (in->getDimensions().d[i] != 1);
}

auto strides = util::toDims(c10::IntArrayRef(strides_vec));
// Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
LOG_DEBUG(ctx->logger, "Expand Tensor: " << in->getName());

return slice_layer->getOutput(0);
}

} // namespace converters
} // namespace conversion
} // namespace core
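
As a standalone illustration of the right-aligned expansion rule enforced by add_expand (a plain C++ sketch for readability, not code from the repository): a dimension is compatible when the existing size equals the target size or is 1, missing leading dimensions behave as size 1, and -1 keeps the existing size but is rejected for newly created leading dimensions.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Compute the concrete expanded shape, mirroring the validation loop in add_expand.
std::vector<int64_t> expand_shape(const std::vector<int64_t>& in, std::vector<int64_t> target) {
  const int64_t in_rank = static_cast<int64_t>(in.size());
  const int64_t out_rank = static_cast<int64_t>(target.size());
  if (in_rank > out_rank)
    throw std::invalid_argument("expansion rank must be >= input rank");
  for (int64_t i = out_rank - 1; i >= 0; --i) {
    const int64_t offset = out_rank - 1 - i;
    const int64_t dim = in_rank - 1 - offset;       // right-aligned input dimension
    const int64_t size = (dim >= 0) ? in[dim] : 1;  // missing leading dims act like size 1
    if (target[i] == -1) {
      if (dim < 0)
        throw std::invalid_argument("-1 is not allowed in a new leading dimension");
      target[i] = size;  // in(3, 1), expand(3, -1, 4) -> (3, 3, 4)
    } else if (size != target[i] && size != 1) {
      throw std::invalid_argument("existing size must match the expanded size or be 1");
    }
  }
  return target;
}

int main() {
  for (int64_t d : expand_shape({3, 1}, {1, 3, 4})) std::cout << d << ' ';  // prints: 1 3 4
  std::cout << '\n';
  // expand_shape({3, 1}, {3, 4, 1}) would throw: existing size 3 cannot match target size 4.
  return 0;
}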
2 changes: 2 additions & 0 deletions core/conversion/converters/converter_util.h
@@ -101,6 +101,8 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s);

nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b);

nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims);

} // namespace converters
} // namespace conversion
} // namespace core
82 changes: 9 additions & 73 deletions core/conversion/converters/impl/expand.cpp
@@ -27,78 +27,14 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe
}
}

bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) {
auto input_dims = in->getDimensions();
TORCHTRT_CHECK(
input_dims.nbDims <= expandedDims.nbDims,
"Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions");

// Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1]
for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) {
int64_t offset = expandedDims.nbDims - 1 - i;
int64_t dim = input_dims.nbDims - 1 - offset;
int64_t size = (dim >= 0) ? input_dims.d[dim] : 1;
int64_t targetSize = expandedDims.d[i];
// In expand layer passing -1 as the size for a dimension means not changing the size of that dimension.
if (targetSize != -1) {
if (size != targetSize) {
if (size != 1) {
TORCHTRT_THROW_ERROR(
"The expanded size of tensor (" << targetSize << ")"
<< " must match the existing size (" << size << ")"
<< " at dimension " << i);
}
}
} else {
// For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but
// not [-1, 3, 4].
if (dim < 0) {
TORCHTRT_THROW_ERROR(
"The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension "
<< i);
} else {
// in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4)
expandedDims.d[i] = input_dims.d[dim];
}
}
}

auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims;
if (num_expand_dims > 0) {
nvinfer1::Dims reshape_dims;
reshape_dims.nbDims = expandedDims.nbDims;
for (int64_t i = 0; i < num_expand_dims; i++) {
reshape_dims.d[i] = 1;
}
for (int64_t i = 0; i < input_dims.nbDims; i++) {
reshape_dims.d[num_expand_dims + i] = input_dims.d[i];
}
// Add a reshape layer to expand dims
auto reshape_layer = ctx->net->addShuffle(*in);
reshape_layer->setReshapeDimensions(reshape_dims);
in = reshape_layer->getOutput(0);
LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims);
}

// Start the slicing from beginning of tensor since this is an expand layer
std::vector<int64_t> start_vec(expandedDims.nbDims, 0);
auto start_offset = util::toDims(c10::IntArrayRef(start_vec));

// Set the stride of non singleton dimension to 1
std::vector<int64_t> strides_vec(expandedDims.nbDims, 0);
for (int64_t i = 0; i < expandedDims.nbDims; i++) {
strides_vec[i] = (in->getDimensions().d[i] != 1);
}

auto strides = util::toDims(c10::IntArrayRef(strides_vec));
// Slice layer does the expansion in TRT. Desired output size is specified by expandedDims
auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides);
slice_layer->setName(util::node_info(n).c_str());

auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0));

bool add_expand_static(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* in,
nvinfer1::Dims expandedDims) {
auto expand_out = add_expand(ctx, in, expandedDims);
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], expand_out);
LOG_DEBUG("Expand layer output tensor shape: " << out->getDimensions());

return true;
}

@@ -209,7 +145,7 @@ auto expand_registrations TORCHTRT_UNUSED =
auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size);
return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true);
} else {
return add_expand(ctx, n, in, expandedDims);
return add_expand_static(ctx, n, in, expandedDims);
}
}})
.pattern(
@@ -223,7 +159,7 @@ auto expand_registrations TORCHTRT_UNUSED =
if (ctx->input_is_dynamic) {
return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false);
} else {
return add_expand(ctx, n, in, targetDims);
return add_expand_static(ctx, n, in, targetDims);
}
}})
.pattern(
47 changes: 7 additions & 40 deletions core/conversion/converters/impl/layer_norm.cpp
@@ -10,41 +10,6 @@ namespace converters {
namespace impl {
namespace {

nvinfer1::ITensor* broadcast(
ConversionCtx* ctx,
const torch::jit::Node* n,
nvinfer1::ITensor* to_broadcast,
const int nbDims,
const std::string& tag) {
auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims;
TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target");
if (to_broadcast_nbdims == nbDims) {
return to_broadcast;
}
auto shape_layer = ctx->net->addShape(*to_broadcast);
TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n);
shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str());
auto shape_layer_out = shape_layer->getOutput(0);

auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32));
auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor);

std::vector<nvinfer1::ITensor*> to_concat = {extra_dims_itensor, shape_layer_out};
auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size());
TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str());
auto target_shape = concat_layer->getOutput(0);

auto shuffle_layer = ctx->net->addShuffle(*to_broadcast);
TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str());
shuffle_layer->setInput(1, *target_shape);
auto output = shuffle_layer->getOutput(0);
LOG_DEBUG(
"Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions());
return output;
}

auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({
R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta,
float eps, bool cudnn_enabled) -> (Tensor))SIG",
@@ -62,20 +27,22 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()

nvinfer1::ITensor* gamma = nullptr;
if (args[2].IValue()->isNone()) {
auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
auto gamma_torch_tensor =
torch::ones(input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType())));
gamma = tensor_to_const(ctx, gamma_torch_tensor);
} else {
gamma = args[2].ITensorOrFreeze(ctx);
gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma");
gamma = add_expand(ctx, gamma, input_shape);
}

nvinfer1::ITensor* beta = nullptr;
if (args[3].IValue()->isNone()) {
auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32));
auto beta_torch_tensor = torch::zeros(
input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType())));
beta = tensor_to_const(ctx, beta_torch_tensor);
} else {
beta = args[3].ITensorOrFreeze(ctx);
beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta");
beta = add_expand(ctx, beta, input_shape);
}

auto eps = args[4].unwrapToDouble();
@@ -84,7 +51,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns()
TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n);
normalize_layer->setName(util::node_info(n).c_str());
normalize_layer->setEpsilon(eps);
normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT);
normalize_layer->setComputePrecision(input->getType());
auto normalized = normalize_layer->getOutput(0);

ctx->AssociateValueAndTensor(n->outputs()[0], normalized);
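
The layer_norm changes above replace the hard-coded FP32 defaults with the input tensor's own type, both for the synthesized gamma/beta constants and for the layer's compute precision. Below is a rough, illustrative stand-in for the kind of mapping util::TRTDataTypeToScalarType is expected to perform (this is not the project's implementation, and the fallback choice is an assumption):

#include <iostream>
#include <NvInfer.h>
#include <torch/torch.h>

// Illustrative mapping from TensorRT dtypes to ATen scalar types, so synthesized
// gamma/beta constants match the layer input's precision (e.g. FP16 input -> Half weights).
at::ScalarType trt_dtype_to_scalar_type(nvinfer1::DataType t) {
  switch (t) {
    case nvinfer1::DataType::kFLOAT: return at::kFloat;
    case nvinfer1::DataType::kHALF:  return at::kHalf;
    case nvinfer1::DataType::kINT32: return at::kInt;
    case nvinfer1::DataType::kINT8:  return at::kChar;
    case nvinfer1::DataType::kBOOL:  return at::kBool;
    default: return at::kFloat;  // assumption: fall back to FP32 for unhandled types
  }
}

int main() {
  auto opts = torch::TensorOptions().dtype(trt_dtype_to_scalar_type(nvinfer1::DataType::kHALF));
  std::cout << torch::ones({4}, opts).scalar_type() << std::endl;  // Half
  return 0;
}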
2 changes: 1 addition & 1 deletion core/lowering/lowering.cpp
@@ -142,11 +142,11 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
passes::SiluToSigmoidMultipication(g);
passes::RemoveSingleUse0DTensors(g);
passes::RemoveUnnecessaryCasts(g);
passes::UnpackScaledDotProductAttention(g);
passes::ReplaceAtenInt(g);
if (lower_info.converting_to_trt_engine) {
passes::RemoveCollectionCast(g);
}
passes::UnpackScaledDotProductAttention(g);
passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());