From 9bbba4fec4882b226ca7a349416b8417e9917758 Mon Sep 17 00:00:00 2001 From: Naren Dasan <1790613+narendasan@users.noreply.github.com> Date: Thu, 15 Aug 2024 10:39:58 -0600 Subject: [PATCH] Fix: Layer norm Torchscript converter (#3062) Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan Co-authored-by: Naren Dasan --- core/conversion/converters/converter_util.cpp | 71 ++++++++++ core/conversion/converters/converter_util.h | 2 + core/conversion/converters/impl/expand.cpp | 82 ++--------- .../conversion/converters/impl/layer_norm.cpp | 47 +------ core/lowering/lowering.cpp | 2 +- .../passes/remove_unnecessary_casts.cpp | 129 ++++++++---------- tests/cpp/test_compiled_modules.cpp | 11 +- tests/modules/hub.py | 11 +- tests/py/ts/models/custom_models.py | 28 ---- tests/py/ts/models/test_models.py | 38 ------ .../test_multiple_registered_engines.py | 1 - 11 files changed, 161 insertions(+), 261 deletions(-) delete mode 100644 tests/py/ts/models/custom_models.py diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp index 39afe9945f..fda153195e 100644 --- a/core/conversion/converters/converter_util.cpp +++ b/core/conversion/converters/converter_util.cpp @@ -438,6 +438,77 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s) { return out; } +nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) { + auto input_dims = in->getDimensions(); + TORCHTRT_CHECK( + input_dims.nbDims <= expandedDims.nbDims, + "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions"); + + // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1] + for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) { + int64_t offset = expandedDims.nbDims - 1 - i; + int64_t dim = input_dims.nbDims - 1 - offset; + int64_t size = (dim >= 0) ? input_dims.d[dim] : 1; + int64_t targetSize = expandedDims.d[i]; + // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension. + if (targetSize != -1) { + if (size != targetSize) { + if (size != 1) { + TORCHTRT_THROW_ERROR( + "The expanded size of tensor (" << targetSize << ")" + << " must match the existing size (" << size << ")" + << " at dimension " << i); + } + } + } else { + // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but + // not [-1, 3, 4]. + if (dim < 0) { + TORCHTRT_THROW_ERROR( + "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension " + << i); + } else { + // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4) + expandedDims.d[i] = input_dims.d[dim]; + } + } + } + + auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims; + if (num_expand_dims > 0) { + nvinfer1::Dims reshape_dims; + reshape_dims.nbDims = expandedDims.nbDims; + for (int64_t i = 0; i < num_expand_dims; i++) { + reshape_dims.d[i] = 1; + } + for (int64_t i = 0; i < input_dims.nbDims; i++) { + reshape_dims.d[num_expand_dims + i] = input_dims.d[i]; + } + // Add a reshape layer to expand dims + auto reshape_layer = ctx->net->addShuffle(*in); + reshape_layer->setReshapeDimensions(reshape_dims); + in = reshape_layer->getOutput(0); + LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims); + } + + // Start the slicing from beginning of tensor since this is an expand layer + std::vector start_vec(expandedDims.nbDims, 0); + auto start_offset = util::toDims(c10::IntArrayRef(start_vec)); + + // Set the stride of non singleton dimension to 1 + std::vector strides_vec(expandedDims.nbDims, 0); + for (int64_t i = 0; i < expandedDims.nbDims; i++) { + strides_vec[i] = (in->getDimensions().d[i] != 1); + } + + auto strides = util::toDims(c10::IntArrayRef(strides_vec)); + // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims + auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides); + LOG_DEBUG(ctx->logger, "Expand Tensor: " << in->getName()); + + return slice_layer->getOutput(0); +} + } // namespace converters } // namespace conversion } // namespace core diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h index ad57c476e1..a97a1c7f8c 100644 --- a/core/conversion/converters/converter_util.h +++ b/core/conversion/converters/converter_util.h @@ -101,6 +101,8 @@ nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, at::Scalar s); nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b); +nvinfer1::ITensor* add_expand(ConversionCtx* ctx, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims); + } // namespace converters } // namespace conversion } // namespace core diff --git a/core/conversion/converters/impl/expand.cpp b/core/conversion/converters/impl/expand.cpp index 0e68768e15..998ae8523e 100644 --- a/core/conversion/converters/impl/expand.cpp +++ b/core/conversion/converters/impl/expand.cpp @@ -27,78 +27,14 @@ nvinfer1::ITensor* concat(int max_rank, int old_rank, ConversionCtx* ctx, nvinfe } } -bool add_expand(ConversionCtx* ctx, const torch::jit::Node* n, nvinfer1::ITensor* in, nvinfer1::Dims expandedDims) { - auto input_dims = in->getDimensions(); - TORCHTRT_CHECK( - input_dims.nbDims <= expandedDims.nbDims, - "Number of dimensions of the desired expansion must be greater than or equal to the number of input dimensions"); - - // Validate the expansion. Eg: an input of [3, 1] can be expanded to [1, 3, 4] but not [3, 4, 1] - for (int64_t i = expandedDims.nbDims - 1; i >= 0; --i) { - int64_t offset = expandedDims.nbDims - 1 - i; - int64_t dim = input_dims.nbDims - 1 - offset; - int64_t size = (dim >= 0) ? input_dims.d[dim] : 1; - int64_t targetSize = expandedDims.d[i]; - // In expand layer passing -1 as the size for a dimension means not changing the size of that dimension. - if (targetSize != -1) { - if (size != targetSize) { - if (size != 1) { - TORCHTRT_THROW_ERROR( - "The expanded size of tensor (" << targetSize << ")" - << " must match the existing size (" << size << ")" - << " at dimension " << i); - } - } - } else { - // For the new dimensions, the size cannot be set to -1. Eg: an input of [3, 1] can be expanded to [3, -1, 4] but - // not [-1, 3, 4]. - if (dim < 0) { - TORCHTRT_THROW_ERROR( - "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, non-existing dimension " - << i); - } else { - // in(3, 1), expand(3, -1, 4) -> expand(3, 3, 4) - expandedDims.d[i] = input_dims.d[dim]; - } - } - } - - auto num_expand_dims = expandedDims.nbDims - input_dims.nbDims; - if (num_expand_dims > 0) { - nvinfer1::Dims reshape_dims; - reshape_dims.nbDims = expandedDims.nbDims; - for (int64_t i = 0; i < num_expand_dims; i++) { - reshape_dims.d[i] = 1; - } - for (int64_t i = 0; i < input_dims.nbDims; i++) { - reshape_dims.d[num_expand_dims + i] = input_dims.d[i]; - } - // Add a reshape layer to expand dims - auto reshape_layer = ctx->net->addShuffle(*in); - reshape_layer->setReshapeDimensions(reshape_dims); - in = reshape_layer->getOutput(0); - LOG_DEBUG("Input reshaped to : " << in->getDimensions() << " from " << input_dims); - } - - // Start the slicing from beginning of tensor since this is an expand layer - std::vector start_vec(expandedDims.nbDims, 0); - auto start_offset = util::toDims(c10::IntArrayRef(start_vec)); - - // Set the stride of non singleton dimension to 1 - std::vector strides_vec(expandedDims.nbDims, 0); - for (int64_t i = 0; i < expandedDims.nbDims; i++) { - strides_vec[i] = (in->getDimensions().d[i] != 1); - } - - auto strides = util::toDims(c10::IntArrayRef(strides_vec)); - // Slice layer does the expansion in TRT. Desired output size is specified by expandedDims - auto slice_layer = ctx->net->addSlice(*in, start_offset, expandedDims, strides); - slice_layer->setName(util::node_info(n).c_str()); - - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_layer->getOutput(0)); - +bool add_expand_static( + ConversionCtx* ctx, + const torch::jit::Node* n, + nvinfer1::ITensor* in, + nvinfer1::Dims expandedDims) { + auto expand_out = add_expand(ctx, in, expandedDims); + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], expand_out); LOG_DEBUG("Expand layer output tensor shape: " << out->getDimensions()); - return true; } @@ -209,7 +145,7 @@ auto expand_registrations TORCHTRT_UNUSED = auto expandedDimsTensor = tensor_to_const(ctx, thExpanded_size); return add_expand_dynamic(ctx, n, in, expandedDimsTensor, expandedDims, true); } else { - return add_expand(ctx, n, in, expandedDims); + return add_expand_static(ctx, n, in, expandedDims); } }}) .pattern( @@ -223,7 +159,7 @@ auto expand_registrations TORCHTRT_UNUSED = if (ctx->input_is_dynamic) { return add_expand_dynamic(ctx, n, in, getShapeOutput(ctx, targetTensor), targetDims, false); } else { - return add_expand(ctx, n, in, targetDims); + return add_expand_static(ctx, n, in, targetDims); } }}) .pattern( diff --git a/core/conversion/converters/impl/layer_norm.cpp b/core/conversion/converters/impl/layer_norm.cpp index 5bc4f1a07e..4bb1c1211b 100644 --- a/core/conversion/converters/impl/layer_norm.cpp +++ b/core/conversion/converters/impl/layer_norm.cpp @@ -10,41 +10,6 @@ namespace converters { namespace impl { namespace { -nvinfer1::ITensor* broadcast( - ConversionCtx* ctx, - const torch::jit::Node* n, - nvinfer1::ITensor* to_broadcast, - const int nbDims, - const std::string& tag) { - auto to_broadcast_nbdims = to_broadcast->getDimensions().nbDims; - TORCHTRT_CHECK(to_broadcast_nbdims <= nbDims, "Cannot broadcast tensor with more dimensions than the target"); - if (to_broadcast_nbdims == nbDims) { - return to_broadcast; - } - auto shape_layer = ctx->net->addShape(*to_broadcast); - TORCHTRT_CHECK(shape_layer, "Unable to create shape layer from node: " << *n); - shape_layer->setName((util::node_info(n) + "_shape_" + tag).c_str()); - auto shape_layer_out = shape_layer->getOutput(0); - - auto extra_dims_tensor = torch::ones({nbDims - to_broadcast_nbdims}, torch::TensorOptions().dtype(torch::kInt32)); - auto extra_dims_itensor = tensor_to_const(ctx, extra_dims_tensor); - - std::vector to_concat = {extra_dims_itensor, shape_layer_out}; - auto concat_layer = ctx->net->addConcatenation(to_concat.data(), to_concat.size()); - TORCHTRT_CHECK(concat_layer, "Unable to create concat layer from node: " << *n); - concat_layer->setName((util::node_info(n) + "_concat_" + tag).c_str()); - auto target_shape = concat_layer->getOutput(0); - - auto shuffle_layer = ctx->net->addShuffle(*to_broadcast); - TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setName((util::node_info(n) + "_shuffle_" + tag).c_str()); - shuffle_layer->setInput(1, *target_shape); - auto output = shuffle_layer->getOutput(0); - LOG_DEBUG( - "Broadcast " << tag << " to shape: " << output->getDimensions() << " from " << to_broadcast->getDimensions()); - return output; -} - auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern({ R"SIG(aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor))SIG", @@ -62,20 +27,22 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() nvinfer1::ITensor* gamma = nullptr; if (args[2].IValue()->isNone()) { - auto gamma_torch_tensor = torch::ones(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32)); + auto gamma_torch_tensor = + torch::ones(input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType()))); gamma = tensor_to_const(ctx, gamma_torch_tensor); } else { gamma = args[2].ITensorOrFreeze(ctx); - gamma = broadcast(ctx, n, gamma, input_shape_vec.size(), "gamma"); + gamma = add_expand(ctx, gamma, input_shape); } nvinfer1::ITensor* beta = nullptr; if (args[3].IValue()->isNone()) { - auto beta_torch_tensor = torch::zeros(input_shape_vec, torch::TensorOptions().dtype(torch::kFloat32)); + auto beta_torch_tensor = torch::zeros( + input_shape_vec, torch::TensorOptions().dtype(util::TRTDataTypeToScalarType(input->getType()))); beta = tensor_to_const(ctx, beta_torch_tensor); } else { beta = args[3].ITensorOrFreeze(ctx); - beta = broadcast(ctx, n, beta, input_shape_vec.size(), "beta"); + beta = add_expand(ctx, beta, input_shape); } auto eps = args[4].unwrapToDouble(); @@ -84,7 +51,7 @@ auto layer_norm_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns() TORCHTRT_CHECK(normalize_layer, "Unable to create layer_norm from node: " << *n); normalize_layer->setName(util::node_info(n).c_str()); normalize_layer->setEpsilon(eps); - normalize_layer->setComputePrecision(nvinfer1::DataType::kFLOAT); + normalize_layer->setComputePrecision(input->getType()); auto normalized = normalize_layer->getOutput(0); ctx->AssociateValueAndTensor(n->outputs()[0], normalized); diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 472d00abac..fd9dbe84a0 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -142,11 +142,11 @@ void LowerGraph(std::shared_ptr& g, std::vector& g) { // Change intermediate op output type LOG_GRAPH(user->schema()); - torch::jit::Node* new_node; - switch (user->kind()) { - // Use this to handle special cases where the scalar version of the intermediate operator - // has a different schema than the original - case c10::aten::add: - new_node = g->create( - user->kind(), - torch::jit::ArrayRef({user->inputs()[0], user->inputs()[1]}), - 1); - new_node->insertAfter(user); - new_node->outputs()[0]->setType(c10::IntType::get()); - user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - user->destroy(); - break; - case c10::aten::floor_divide: - new_node = g->create(c10::aten::floordiv, user->inputs(), 1); - new_node->insertAfter(user); - new_node->outputs()[0]->setType(c10::IntType::get()); - user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - user->destroy(); - break; - case c10::aten::div: - // If the first two entries to aten::div are non-Tensors, - // there cannot be a rounding mode specified (3rd entry) - if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) && - !user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) && - user->inputs().size() == 3 && - user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) && - torch::jit::toIValue(user->inputs()[2]).has_value()) { - // Select the first 2 entries of the inputs, corresponding to the values - auto div_args = user->inputs().slice(0, 2); - - // Depending on the rounding mode, create the appropriate nodes - if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") { - // Truncate case (round result towards 0) - torch::jit::Node* new_node_div; - // Create node which simply divides the two entries - new_node_div = g->create(c10::aten::div, div_args, 1); - new_node_div->insertAfter(user); - new_node_div->outputs()[0]->setType(c10::FloatType::get()); - - // Create node which casts the result to an integer, effectively truncating - new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1); - new_node->insertAfter(new_node_div); - new_node->outputs()[0]->setType(c10::IntType::get()); - - user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - user->destroy(); - break; - - } else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") { - // Floor case (round result down) - // Replace aten::div with aten::floordiv - new_node = g->create(c10::aten::floordiv, div_args, 1); - new_node->insertAfter(user); - new_node->outputs()[0]->setType(c10::IntType::get()); - - user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - user->destroy(); - break; - } + // Use this to handle special cases where the scalar version of the intermediate operator + // has a different schema than the original + if (user->kind() == c10::Symbol::fromQualString("aten::add")) { + new_node = g->create( + c10::Symbol::fromQualString("aten::add"), + torch::jit::ArrayRef({user->inputs()[0], user->inputs()[1]}), + 1); + new_node->insertAfter(user); + new_node->outputs()[0]->setType(c10::IntType::get()); + user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + user->destroy(); + } else if (user->kind() == c10::Symbol::fromQualString("aten::floordiv")) { + new_node = g->create(c10::aten::floordiv, user->inputs(), 1); + new_node->insertAfter(user); + new_node->outputs()[0]->setType(c10::IntType::get()); + user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + user->destroy(); + } else if (user->kind() == c10::Symbol::fromQualString("aten::div")) { + // If the first two entries to aten::div are non-Tensors, + // there cannot be a rounding mode specified (3rd entry) + if (!user->inputs()[0]->type()->isSubtypeOf(c10::TensorType::get()) && + !user->inputs()[1]->type()->isSubtypeOf(c10::TensorType::get()) && + user->inputs().size() == 3 && + user->inputs()[2]->type()->isSubtypeOf(c10::StringType::get()) && + torch::jit::toIValue(user->inputs()[2]).has_value()) { + // Select the first 2 entries of the inputs, corresponding to the values + auto div_args = user->inputs().slice(0, 2); + + // Depending on the rounding mode, create the appropriate nodes + if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "trunc") { + // Truncate case (round result towards 0) + torch::jit::Node* new_node_div; + // Create node which simply divides the two entries + new_node_div = g->create(c10::aten::div, div_args, 1); + new_node_div->insertAfter(user); + new_node_div->outputs()[0]->setType(c10::FloatType::get()); + + // Create node which casts the result to an integer, effectively truncating + new_node = g->create(c10::aten::Int, new_node_div->outputs(), 1); + new_node->insertAfter(new_node_div); + new_node->outputs()[0]->setType(c10::IntType::get()); + + user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + user->destroy(); + } else if (torch::jit::toIValue(user->inputs()[2]).value().toStringRef() == "floor") { + // Floor case (round result down) + // Replace aten::div with aten::floordiv + new_node = g->create(c10::aten::floordiv, div_args, 1); + new_node->insertAfter(user); + new_node->outputs()[0]->setType(c10::IntType::get()); + + user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + user->destroy(); } - - default: - new_node = g->create(user->kind(), user->inputs(), 1); - new_node->insertAfter(user); - new_node->outputs()[0]->setType(c10::IntType::get()); - user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); - user->destroy(); - break; + } + } else { + new_node = g->create(user->kind(), user->inputs(), 1); + new_node->insertAfter(user); + new_node->outputs()[0]->setType(c10::IntType::get()); + user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]); + user->destroy(); } LOG_GRAPH("New intermediate operation: " << *new_node); diff --git a/tests/cpp/test_compiled_modules.cpp b/tests/cpp/test_compiled_modules.cpp index 7def168249..e2b2273f71 100644 --- a/tests/cpp/test_compiled_modules.cpp +++ b/tests/cpp/test_compiled_modules.cpp @@ -5,7 +5,11 @@ TEST_P(CppAPITests, CompiledModuleIsClose) { std::vector trt_inputs_ivalues; std::vector shapes; for (uint64_t i = 0; i < input_shapes.size(); i++) { - auto in = at::randint(5, input_shapes[i], {at::kCUDA}).to(input_types[i]); + auto in = at::randn(input_shapes[i], {at::kCUDA}).to(input_types[i]); + if (input_types[i] == at::kInt || input_types[i] == at::kLong) { + auto in = at::randint(0, 2, input_shapes[i], {at::kCUDA}).to(input_types[i]); + } + jit_inputs_ivalues.push_back(in.clone()); trt_inputs_ivalues.push_back(in.clone()); auto in_spec = torch_tensorrt::Input(input_shapes[i]); @@ -58,9 +62,6 @@ INSTANTIATE_TEST_SUITE_P( PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}), - PathAndInput({"tests/modules/bert_base_uncased_traced.jit.pt", {{1, 14}, {1, 14}}, {at::kInt, at::kInt}}))); -// NOTE: ViT tests are disabled until Python 3.11 issue is resolved -// https://github.com/huggingface/pytorch-image-models/issues/1946 PathAndInput({"tests/modules/vit_scripted.jit.pt", -// {{1, 3, 224, 224}}, {at::kFloat}}))); + PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}}))); #endif diff --git a/tests/modules/hub.py b/tests/modules/hub.py index 6316c8ddb4..0cce523fb3 100644 --- a/tests/modules/hub.py +++ b/tests/modules/hub.py @@ -51,11 +51,10 @@ "model": timm.create_model("efficientnet_b0", pretrained=True), "path": "script", }, - # NOTE: Disabling ViT until support in 3.11 is fixed https://github.com/huggingface/pytorch-image-models/issues/1946 - # "vit": { - # "model": timm.create_model("vit_base_patch16_224", pretrained=True), - # "path": "script", - # }, + "vit": { + "model": timm.create_model("vit_base_patch16_224", pretrained=True), + "path": "script", + }, "pooling": {"model": cm.Pool(), "path": "trace"}, "module_fallback": {"model": cm.ModuleFallbackMain(), "path": "script"}, "loop_fallback_eval": {"model": cm.LoopFallbackEval(), "path": "script"}, @@ -68,7 +67,7 @@ "tuple_input_output": {"model": cm.TupleInputOutput(), "path": "script"}, "list_input_output": {"model": cm.ListInputOutput(), "path": "script"}, "list_input_tuple_output": {"model": cm.ListInputTupleOutput(), "path": "script"}, - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, + # "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, } diff --git a/tests/py/ts/models/custom_models.py b/tests/py/ts/models/custom_models.py deleted file mode 100644 index a19b9ca81c..0000000000 --- a/tests/py/ts/models/custom_models.py +++ /dev/null @@ -1,28 +0,0 @@ -import torch -from transformers import BertModel, BertTokenizer, BertConfig - - -def BertModule(): - model_name = "bert-base-uncased" - enc = BertTokenizer.from_pretrained(model_name) - text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]" - tokenized_text = enc.tokenize(text) - masked_index = 8 - tokenized_text[masked_index] = "[MASK]" - indexed_tokens = enc.convert_tokens_to_ids(tokenized_text) - segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1] - tokens_tensor = torch.tensor([indexed_tokens]) - segments_tensors = torch.tensor([segments_ids]) - config = BertConfig( - vocab_size_or_config_json_file=32000, - hidden_size=768, - num_hidden_layers=12, - num_attention_heads=12, - intermediate_size=3072, - torchscript=True, - ) - model = BertModel(config) - model.eval() - model = BertModel.from_pretrained(model_name, torchscript=True) - traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors]) - return traced_model diff --git a/tests/py/ts/models/test_models.py b/tests/py/ts/models/test_models.py index 2195ff3708..bae94374d8 100644 --- a/tests/py/ts/models/test_models.py +++ b/tests/py/ts/models/test_models.py @@ -2,7 +2,6 @@ import unittest from typing import Dict -import custom_models as cm import timm import torch import torch_tensorrt as torchtrt @@ -92,43 +91,6 @@ def test_efficientnet_b0(self): msg=f"EfficientNet-B0 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - @unittest.skip("Layer Norm issue needs to be addressed") - def test_bert_base_uncased(self): - self.model = cm.BertModule().cuda() - self.input = torch.randint(0, 2, (1, 14), dtype=torch.int32).to("cuda") - - compile_spec = { - "inputs": [ - torchtrt.Input( - self.input.shape, - dtype=self.input.dtype, - format=torch.contiguous_format, - ), - torchtrt.Input( - self.input.shape, - dtype=self.input.dtype, - format=torch.contiguous_format, - ), - ], - "device": { - "device_type": torchtrt.DeviceType.GPU, - "gpu_id": 0, - }, - "enabled_precisions": {torch.float}, - "truncate_long_and_double": True, - } - with torchtrt.logging.errors(): - trt_mod = torchtrt.ts.compile(self.model, **compile_spec) - - model_outputs = self.model(self.input, self.input) - trt_model_outputs = trt_mod(self.input, self.input) - for out, trt_out in zip(model_outputs, trt_model_outputs): - cos_sim = cosine_similarity(out, trt_out) - self.assertTrue( - cos_sim > COSINE_THRESHOLD, - msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - ) - def test_resnet18_half(self): self.model = models.resnet18(pretrained=True).eval().to("cuda") self.input = torch.randn((1, 3, 224, 224)).to("cuda") diff --git a/tests/py/ts/models/test_multiple_registered_engines.py b/tests/py/ts/models/test_multiple_registered_engines.py index 407502f04a..3dbd724755 100644 --- a/tests/py/ts/models/test_multiple_registered_engines.py +++ b/tests/py/ts/models/test_multiple_registered_engines.py @@ -2,7 +2,6 @@ import unittest from typing import Dict -import custom_models as cm import timm import torch import torch_tensorrt as torchtrt