From dc28c324462b2d93ee24354f39c761dc900acedd Mon Sep 17 00:00:00 2001 From: Bart Chrzaszcz Date: Mon, 23 Oct 2023 04:37:45 -0700 Subject: [PATCH 001/246] Print optional attr-dict on `stablehlo.reduce`. In the compact version of printing, any custom attributes a user of StableHLO sets will not appear. This makes sure they appear if they exist. PiperOrigin-RevId: 575778539 --- third_party/stablehlo/temporary.patch | 38 +++++++++++++++++++ .../xla/third_party/stablehlo/temporary.patch | 38 +++++++++++++++++++ 2 files changed, 76 insertions(+) diff --git a/third_party/stablehlo/temporary.patch b/third_party/stablehlo/temporary.patch index 85792ae3fc8702..820d2dea0fb42f 100644 --- a/third_party/stablehlo/temporary.patch +++ b/third_party/stablehlo/temporary.patch @@ -980,6 +980,25 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +} // namespace mlir + +#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -1543,6 +1543,7 @@ + p << " across dimensions = ["; + llvm::interleaveComma(getDimensions().getValues(), p); + p << "]"; ++ p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); + p << " : "; + p.printFunctionalType(*this); + } else { +@@ -1705,6 +1706,7 @@ + if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || + parser.parseEqual() || + parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(reduceOpFnType) || + parser.parseOptionalLocationSpecifier(explicitLoc)) + return failure(); diff --ruN a/stablehlo/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/stablehlo/tests/infer_stablehlo.mlir --- stablehlo/stablehlo/tests/infer_stablehlo.mlir +++ stablehlo/stablehlo/tests/infer_stablehlo.mlir @@ -1078,6 +1097,25 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ }> func.func @is_compatible_sparse_mix_non_sparse(%arg0: tensor<1xf32>, %arg1: tensor<1xf32, #SV>) { +diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir +--- stablehlo/stablehlo/tests/print_reduce.mlir ++++ stablehlo/stablehlo/tests/print_reduce.mlir +@@ -168,3 +168,15 @@ + + func.return %0: tensor<4xf32> + } ++ ++// The test case makes sure any custom attrs set on the reduce-op are ++// printed/parsed when pretty-printed. 
++ ++// CHECK-LABEL: func @pretty_print_with_custom_attr ++// CHECK: applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} ++ ++func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { ++ %0 = stablehlo.constant dense<0.000000e+00> : tensor ++ %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> ++ return %1 : tensor<2x13xf32> ++} diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir diff --git a/third_party/xla/third_party/stablehlo/temporary.patch b/third_party/xla/third_party/stablehlo/temporary.patch index 85792ae3fc8702..820d2dea0fb42f 100644 --- a/third_party/xla/third_party/stablehlo/temporary.patch +++ b/third_party/xla/third_party/stablehlo/temporary.patch @@ -980,6 +980,25 @@ diff --ruN a/stablehlo/stablehlo/dialect/ExperimentalOps.h b/stablehlo/stablehlo +} // namespace mlir + +#endif // STABLEHLO_DIALECT_EXPERIMENTAL_OPS_H +diff --ruN a/stablehlo/stablehlo/dialect/StablehloOps.cpp b/stablehlo/stablehlo/dialect/StablehloOps.cpp +--- stablehlo/stablehlo/dialect/StablehloOps.cpp ++++ stablehlo/stablehlo/dialect/StablehloOps.cpp +@@ -1543,6 +1543,7 @@ + p << " across dimensions = ["; + llvm::interleaveComma(getDimensions().getValues(), p); + p << "]"; ++ p.printOptionalAttrDict(getOperation()->getAttrs(), {"dimensions"}); + p << " : "; + p.printFunctionalType(*this); + } else { +@@ -1705,6 +1706,7 @@ + if (parser.parseKeyword("across") || parser.parseKeyword("dimensions") || + parser.parseEqual() || + parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseDim) || ++ parser.parseOptionalAttrDict(result.attributes) || + parser.parseColon() || parser.parseType(reduceOpFnType) || + parser.parseOptionalLocationSpecifier(explicitLoc)) + return failure(); diff --ruN a/stablehlo/stablehlo/tests/infer_stablehlo.mlir b/stablehlo/stablehlo/tests/infer_stablehlo.mlir --- stablehlo/stablehlo/tests/infer_stablehlo.mlir +++ stablehlo/stablehlo/tests/infer_stablehlo.mlir @@ -1078,6 +1097,25 @@ diff --ruN a/stablehlo/stablehlo/tests/ops_stablehlo.mlir b/stablehlo/stablehlo/ }> func.func @is_compatible_sparse_mix_non_sparse(%arg0: tensor<1xf32>, %arg1: tensor<1xf32, #SV>) { +diff --ruN a/stablehlo/stablehlo/tests/print_reduce.mlir b/stablehlo/stablehlo/tests/print_reduce.mlir +--- stablehlo/stablehlo/tests/print_reduce.mlir ++++ stablehlo/stablehlo/tests/print_reduce.mlir +@@ -168,3 +168,15 @@ + + func.return %0: tensor<4xf32> + } ++ ++// The test case makes sure any custom attrs set on the reduce-op are ++// printed/parsed when pretty-printed. 
++ ++// CHECK-LABEL: func @pretty_print_with_custom_attr ++// CHECK: applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} ++ ++func.func @pretty_print_with_custom_attr(%arg0: tensor<2x64x13xf32>) -> tensor<2x13xf32> { ++ %0 = stablehlo.constant dense<0.000000e+00> : tensor ++ %1 = stablehlo.reduce(%arg0 init: %0) applies stablehlo.add across dimensions = [1] {custom_user_attr = 1 : i64} : (tensor<2x64x13xf32>, tensor) -> tensor<2x13xf32> ++ return %1 : tensor<2x13xf32> ++} diff --ruN a/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir b/stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir --- stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir +++ stablehlo/stablehlo/tests/stablehlo_canonicalize_dynamism.mlir From a39d7df972aaedc3a2bf4a986e6e3d48f5f3eab5 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 23 Oct 2023 05:56:27 -0700 Subject: [PATCH 002/246] PR #6475: Fix flaky test bitcast_dtypes_expander_test Imported from GitHub PR https://github.com/openxla/xla/pull/6475 `bitcast_dtypes_expander_test` failed when executed in Docker container ``` //xla/service:bitcast_dtypes_expander_test FAILED in 3 out of 3 in 0.7s ``` Reason - swapped variable names in `S64toS32` test ``` %shift-right-logical.11 -> %shift-right-logical.12 %constant.12 -> %constant.11 ``` To resolve the issue we can use pattern based var names instead. ``` %[[VAL_10:.*]] %[[VAL_11:.*]] ``` Copybara import of the project: -- b99ed813b39c6b53f1aa0981dbaae060d8cb5aa8 by Alexander Pivovarov : Fix flaky test bitcast_dtypes_expander_test Merging this change closes #6475 PiperOrigin-RevId: 575793343 --- .../service/bitcast_dtypes_expander_test.cc | 103 +++++++++--------- 1 file changed, 54 insertions(+), 49 deletions(-) diff --git a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc index 00b5ec2180525c..b400b28cf8824c 100644 --- a/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc +++ b/third_party/xla/xla/service/bitcast_dtypes_expander_test.cc @@ -45,26 +45,26 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: HloModule bitcast_to_smaller // CHECK: %xla.bitcast_convert_s32_10__2_s8_10_4_.17 (a.1: s32[10]) -> s8[10,4] { -// CHECK: %a.1 = s32[10]{0} parameter(0) -// CHECK: %reshape.2 = s32[10,1]{1,0} reshape(s32[10]{0} %a.1) -// CHECK: %broadcast.3 = s32[10,1]{1,0} broadcast(s32[10,1]{1,0} %reshape.2), dimensions={0,1} -// CHECK: %reshape.4 = s32[10]{0} reshape(s32[10,1]{1,0} %broadcast.3) -// CHECK: %broadcast.5 = s32[10,4]{1,0} broadcast(s32[10]{0} %reshape.4), dimensions={0} -// CHECK: %bitcast-convert.6 = u32[10,4]{1,0} bitcast-convert(s32[10,4]{1,0} %broadcast.5) -// CHECK: %constant.8 = u32[] constant(8) -// CHECK: %broadcast.9 = u32[10,4]{1,0} broadcast(u32[] %constant.8), dimensions={} -// CHECK: %iota.7 = u32[10,4]{1,0} iota(), iota_dimension=1 -// CHECK: %multiply.10 = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %broadcast.9, u32[10,4]{1,0} %iota.7) -// CHECK: %shift-right-logical{{\.?[0-9]*}} = u32[10,4]{1,0} shift-right-logical(u32[10,4]{1,0} %bitcast-convert.6, u32[10,4]{1,0} %multiply.10) -// CHECK: %constant{{\.?[0-9]*}} = u32[] constant(255) -// CHECK: %broadcast.13 = u32[10,4]{1,0} broadcast(u32[] %constant{{\.?[0-9]*}}), dimensions={} -// CHECK: %and.14 = u32[10,4]{1,0} and(u32[10,4]{1,0} %shift-right-logical{{\.?[0-9]*}}, u32[10,4]{1,0} %broadcast.13) -// CHECK: %convert.15 = u8[10,4]{1,0} convert(u32[10,4]{1,0} %and.14) -// CHECK: ROOT 
%bitcast-convert.16 = s8[10,4]{1,0} bitcast-convert(u8[10,4]{1,0} %convert.15) +// CHECK: %[[VAL_0:.*]] = s32[10]{0} parameter(0) +// CHECK: %[[VAL_1:.*]] = s32[10,1]{1,0} reshape(s32[10]{0} %[[VAL_0]]) +// CHECK: %[[VAL_2:.*]] = s32[10,1]{1,0} broadcast(s32[10,1]{1,0} %[[VAL_1]]), dimensions={0,1} +// CHECK: %[[VAL_3:.*]] = s32[10]{0} reshape(s32[10,1]{1,0} %[[VAL_2]]) +// CHECK: %[[VAL_4:.*]] = s32[10,4]{1,0} broadcast(s32[10]{0} %[[VAL_3]]), dimensions={0} +// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} bitcast-convert(s32[10,4]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_6:.*]] = u32[] constant(8) +// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1 +// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-right-logical(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_11:.*]] = u32[] constant(255) +// CHECK: %[[VAL_12:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_11]]), dimensions={} +// CHECK: %[[VAL_13:.*]] = u32[10,4]{1,0} and(u32[10,4]{1,0} %[[VAL_10]], u32[10,4]{1,0} %[[VAL_12]]) +// CHECK: %[[VAL_14:.*]] = u8[10,4]{1,0} convert(u32[10,4]{1,0} %[[VAL_13]]) +// CHECK: ROOT %[[VAL_15:.*]] = s8[10,4]{1,0} bitcast-convert(u8[10,4]{1,0} %[[VAL_14]]) // CHECK: } // CHECK: ENTRY %main (p: s32[10]) -> s8[10,4] { -// CHECK: %p = s32[10]{0} parameter(0) -// CHECK: ROOT %call = s8[10,4]{1,0} call(s32[10]{0} %p), to_apply=%xla.bitcast_convert_s32_10__2_s8_10_4_.17 +// CHECK: %[[VAL_16:.*]] = s32[10]{0} parameter(0) +// CHECK: ROOT %[[VAL_17:.*]] = s8[10,4]{1,0} call(s32[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]] // CHECK: } )")); } @@ -88,26 +88,26 @@ ENTRY main { EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: HloModule bitcast_to_smaller, entry_computation_layout={(s64[10]{0})->s32[10,2]{1,0}} // CHECK: %xla.bitcast_convert_s64_10__2_s32_10_2_.17 (a.1: s64[10]) -> s32[10,2] { -// CHECK: %a.1 = s64[10]{0} parameter(0) -// CHECK: %reshape.2 = s64[10,1]{1,0} reshape(s64[10]{0} %a.1) -// CHECK: %broadcast.3 = s64[10,1]{1,0} broadcast(s64[10,1]{1,0} %reshape.2), dimensions={0,1} -// CHECK: %reshape.4 = s64[10]{0} reshape(s64[10,1]{1,0} %broadcast.3) -// CHECK: %broadcast.5 = s64[10,2]{1,0} broadcast(s64[10]{0} %reshape.4), dimensions={0} -// CHECK: %bitcast-convert.6 = u64[10,2]{1,0} bitcast-convert(s64[10,2]{1,0} %broadcast.5) -// CHECK: %constant.8 = u64[] constant(32) -// CHECK: %broadcast.9 = u64[10,2]{1,0} broadcast(u64[] %constant.8), dimensions={} -// CHECK: %iota.7 = u64[10,2]{1,0} iota(), iota_dimension=1 -// CHECK: %multiply.10 = u64[10,2]{1,0} multiply(u64[10,2]{1,0} %broadcast.9, u64[10,2]{1,0} %iota.7) -// CHECK: %shift-right-logical.11 = u64[10,2]{1,0} shift-right-logical(u64[10,2]{1,0} %bitcast-convert.6, u64[10,2]{1,0} %multiply.10) -// CHECK: %constant.12 = u64[] constant(4294967295) -// CHECK: %broadcast.13 = u64[10,2]{1,0} broadcast(u64[] %constant.12), dimensions={} -// CHECK: %and.14 = u64[10,2]{1,0} and(u64[10,2]{1,0} %shift-right-logical.11, u64[10,2]{1,0} %broadcast.13) -// CHECK: %convert.15 = u32[10,2]{1,0} convert(u64[10,2]{1,0} %and.14) -// CHECK: ROOT %bitcast-convert.16 = s32[10,2]{1,0} bitcast-convert(u32[10,2]{1,0} %convert.15) +// CHECK: %[[VAL_0:.*]] = s64[10]{0} parameter(0) +// CHECK: %[[VAL_1:.*]] = s64[10,1]{1,0} reshape(s64[10]{0} %[[VAL_0]]) +// CHECK: %[[VAL_2:.*]] = s64[10,1]{1,0} broadcast(s64[10,1]{1,0} %[[VAL_1]]), dimensions={0,1} +// 
CHECK: %[[VAL_3:.*]] = s64[10]{0} reshape(s64[10,1]{1,0} %[[VAL_2]]) +// CHECK: %[[VAL_4:.*]] = s64[10,2]{1,0} broadcast(s64[10]{0} %[[VAL_3]]), dimensions={0} +// CHECK: %[[VAL_5:.*]] = u64[10,2]{1,0} bitcast-convert(s64[10,2]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_6:.*]] = u64[] constant(32) +// CHECK: %[[VAL_7:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_8:.*]] = u64[10,2]{1,0} iota(), iota_dimension=1 +// CHECK: %[[VAL_9:.*]] = u64[10,2]{1,0} multiply(u64[10,2]{1,0} %[[VAL_7]], u64[10,2]{1,0} %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u64[10,2]{1,0} shift-right-logical(u64[10,2]{1,0} %[[VAL_5]], u64[10,2]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_11:.*]] = u64[] constant(4294967295) +// CHECK: %[[VAL_12:.*]] = u64[10,2]{1,0} broadcast(u64[] %[[VAL_11]]), dimensions={} +// CHECK: %[[VAL_13:.*]] = u64[10,2]{1,0} and(u64[10,2]{1,0} %[[VAL_10]], u64[10,2]{1,0} %[[VAL_12]]) +// CHECK: %[[VAL_14:.*]] = u32[10,2]{1,0} convert(u64[10,2]{1,0} %[[VAL_13]]) +// CHECK: ROOT %[[VAL_15:.*]] = s32[10,2]{1,0} bitcast-convert(u32[10,2]{1,0} %[[VAL_14]]) // CHECK: } // CHECK: ENTRY %main (p: s64[10]) -> s32[10,2] { -// CHECK: %p = s64[10]{0} parameter(0) -// CHECK: ROOT %call = s32[10,2]{1,0} call(s64[10]{0} %p), to_apply=%xla.bitcast_convert_s64_10__2_s32_10_2_.17 +// CHECK: %[[VAL_16:.*]] = s64[10]{0} parameter(0) +// CHECK: ROOT %[[VAL_17:.*]] = s32[10,2]{1,0} call(s64[10]{0} %[[VAL_16]]), to_apply=%[[VAL_18:.*]] // CHECK: } )")); } @@ -132,22 +132,27 @@ ENTRY main { EXPECT_TRUE(changed); EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( // CHECK: HloModule bitcast_to_larger +// CHECK: %or_U32.10 (lhs.11: u32[], rhs.12: u32[]) -> u32[] { +// CHECK: %[[VAL_0:.*]] = u32[] parameter(0) +// CHECK: %[[VAL_1:.*]] = u32[] parameter(1) +// CHECK: ROOT %[[VAL_2:.*]] = u32[] or(u32[] %[[VAL_0]], u32[] %[[VAL_1]]) +// CHECK: } // CHECK: %xla.bitcast_convert_s8_10_4__2_s32_10_.16 (a.1: s8[10,4]) -> s32[10] { -// CHECK: %a.1 = s8[10,4]{1,0} parameter(0) -// CHECK: %bitcast-convert.2 = u8[10,4]{1,0} bitcast-convert(s8[10,4]{1,0} %a.1) -// CHECK: %convert.3 = u32[10,4]{1,0} convert(u8[10,4]{1,0} %bitcast-convert.2) -// CHECK: %constant{{\.?[0-9]*}} = u32[] constant(8) -// CHECK: %broadcast.6 = u32[10,4]{1,0} broadcast(u32[] %constant{{\.?[0-9]*}}), dimensions={} -// CHECK: %iota{{\.?[0-9]*}} = u32[10,4]{1,0} iota(), iota_dimension=1 -// CHECK: %multiply.7 = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %broadcast.6, u32[10,4]{1,0} %iota{{\.?[0-9]*}}) -// CHECK: %shift-left.8 = u32[10,4]{1,0} shift-left(u32[10,4]{1,0} %convert.3, u32[10,4]{1,0} %multiply.7) -// CHECK: %constant.9 = u32[] constant(0) -// CHECK: %reduce.14 = u32[10]{0} reduce(u32[10,4]{1,0} %shift-left.8, u32[] %constant.9), dimensions={1}, to_apply=%or_U32.10 -// CHECK: ROOT %bitcast-convert.15 = s32[10]{0} bitcast-convert(u32[10]{0} %reduce.14) +// CHECK: %[[VAL_3:.*]] = s8[10,4]{1,0} parameter(0) +// CHECK: %[[VAL_4:.*]] = u8[10,4]{1,0} bitcast-convert(s8[10,4]{1,0} %[[VAL_3]]) +// CHECK: %[[VAL_5:.*]] = u32[10,4]{1,0} convert(u8[10,4]{1,0} %[[VAL_4]]) +// CHECK: %[[VAL_6:.*]] = u32[] constant(8) +// CHECK: %[[VAL_7:.*]] = u32[10,4]{1,0} broadcast(u32[] %[[VAL_6]]), dimensions={} +// CHECK: %[[VAL_8:.*]] = u32[10,4]{1,0} iota(), iota_dimension=1 +// CHECK: %[[VAL_9:.*]] = u32[10,4]{1,0} multiply(u32[10,4]{1,0} %[[VAL_7]], u32[10,4]{1,0} %[[VAL_8]]) +// CHECK: %[[VAL_10:.*]] = u32[10,4]{1,0} shift-left(u32[10,4]{1,0} %[[VAL_5]], u32[10,4]{1,0} %[[VAL_9]]) +// CHECK: %[[VAL_11:.*]] = u32[] constant(0) +// CHECK: %[[VAL_12:.*]] = 
u32[10]{0} reduce(u32[10,4]{1,0} %[[VAL_10]], u32[] %[[VAL_11]]), dimensions={1}, to_apply=%[[VAL_13:.*]] +// CHECK: ROOT %[[VAL_14:.*]] = s32[10]{0} bitcast-convert(u32[10]{0} %[[VAL_12]]) // CHECK: } // CHECK: ENTRY %main (p: s8[10,4]) -> s32[10] { -// CHECK: %p = s8[10,4]{1,0} parameter(0) -// CHECK: ROOT %call = s32[10]{0} call(s8[10,4]{1,0} %p), to_apply=%xla.bitcast_convert_s8_10_4__2_s32_10_.16 +// CHECK: %[[VAL_15:.*]] = s8[10,4]{1,0} parameter(0) +// CHECK: ROOT %[[VAL_16:.*]] = s32[10]{0} call(s8[10,4]{1,0} %[[VAL_15]]), to_apply=%[[VAL_17:.*]] // CHECK: } )")); } From 6f30fc1fd7ff45940aa2ce17042facde66394948 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 23 Oct 2023 06:04:04 -0700 Subject: [PATCH 003/246] Integrate LLVM at llvm/llvm-project@e558be51bab0 Updates LLVM usage to match [e558be51bab0](https://github.com/llvm/llvm-project/commit/e558be51bab0) PiperOrigin-RevId: 575794929 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 41c19b57550b2c..8be65a702ab288 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "49af6502c6dcb4a7f7520178bd14df396f78240c" - LLVM_SHA256 = "3900937d00f7f9a462469288d25bccb8c75544967ded15848fdda4dd1837b86a" + LLVM_COMMIT = "e558be51bab051d1471d92e967f8a2aecc13567a" + LLVM_SHA256 = "94069d8ccbab6451c7b070f26d926c310cdba17481ebc997273d95e0c82d86f8" tf_http_archive( name = name, From a02bae3e033fc325c322b12c596411b181108cbc Mon Sep 17 00:00:00 2001 From: Michael Hudgins Date: Mon, 23 Oct 2023 06:27:19 -0700 Subject: [PATCH 004/246] Add libcublas-12-2 to prevent the libnvinfer8 dependency from pulling cuda 12.3 dependencies. Otherwise, when pulled, the 12.3 will change the alias of /usr/local/cuda to 12.3, causing failures. PiperOrigin-RevId: 575799576 --- tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt index f0f05bef74639f..6cfcaf4ec078a8 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.packages.txt @@ -11,6 +11,7 @@ libcufft-12-2 libcurand-12-2 libcusolver-dev-12-2 libcusparse-dev-12-2 +libcublas-12-2 libcublas-dev-12-2 libnccl-dev=2.18.5-1+cuda12.2 # CuDNN: https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#ubuntu-network-installation From c251034d99a7e20126d7742136c93fa045dc078f Mon Sep 17 00:00:00 2001 From: George Necula Date: Mon, 23 Oct 2023 07:29:35 -0700 Subject: [PATCH 005/246] [XlaCallModule] Add more logging. Log the op attributes on `--vmodule=xla_call_module_op=3`. 
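For reference, the guarded logging this change adds follows the standard VLOG pattern. A minimal standalone sketch (assuming TSL's logging header as the source of VLOG / VLOG_IS_ON; the helper and its signature are illustrative, not the actual kernel code):

```
#include <string>
#include <vector>

#include "absl/strings/str_join.h"
#include "tsl/platform/logging.h"  // assumed source of VLOG / VLOG_IS_ON

// Emits the op's configuration only when verbose logging is enabled for
// this file, e.g. by running with --vmodule=xla_call_module_op=3.
void LogXlaCallModuleInit(int version,
                          const std::vector<std::string>& platforms,
                          bool has_token_input_output) {
  if (VLOG_IS_ON(3)) {
    VLOG(3) << "Initializing XlaCallModuleOp (version = " << version
            << ", platforms = [" << absl::StrJoin(platforms, ", ")
            << "], has_token_input_output = " << has_token_input_output
            << ")";
  }
}
```

The VLOG_IS_ON guard keeps the string formatting off the hot path when verbose logging is disabled.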
PiperOrigin-RevId: 575812286 --- tensorflow/compiler/tf2xla/kernels/BUILD | 1 + .../tf2xla/kernels/xla_call_module_op.cc | 35 ++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index e8d7cd052b5fba..e1605b5c3d3bdf 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -369,6 +369,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc index 200b24951ed68b..416eb5fa75bd2c 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" @@ -150,7 +151,30 @@ class XlaCallModuleOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("disabled_checks", &disabled_checks)); std::vector platforms; OP_REQUIRES_OK(ctx, ctx->GetAttr("platforms", &platforms)); + // TODO(necula): change this to OP_REQUIRES_OK when 6 months have passed + // since we added the function_list and has_token_input_output + // attributes (May 25, 2023). + if (!ctx->GetAttr("has_token_input_output", &module_has_token_input_output_) + .ok()) { + module_has_token_input_output_ = false; + } + if (!ctx->GetAttr("function_list", &function_list_).ok()) { + function_list_.clear(); + } + if (VLOG_IS_ON(3)) { + VLOG(3) << "Initializing XlaCallModuleOp (version = " << version + << ", platforms = [" << absl::StrJoin(platforms, ", ") + << "], has_token_input_output = " + << module_has_token_input_output_ << ", disabled_checks = [" + << absl::StrJoin(disabled_checks, ", ") << "], " + << "function_list = [" + << absl::StrJoin(function_list_, ",", + [](std::string *out, NameAttrList x) { + absl::StrAppend(out, x.name()); + }) + << "])"; + } string loading_device_type = ctx->device_type().type_string(); string loading_platform = ""; if (loading_device_type == DEVICE_CPU_XLA_JIT) { @@ -171,11 +195,8 @@ class XlaCallModuleOp : public XlaOpKernel { absl::UnimplementedError(absl::StrCat( "Unexpected device type ", loading_device_type))); } - VLOG(3) << "Initialized XlaCallModuleOp on " << loading_platform; - if (!ctx->GetAttr("has_token_input_output", &module_has_token_input_output_) - .ok()) { - module_has_token_input_output_ = false; - } + VLOG(3) << "Initializing XlaCallModuleOp on " << loading_platform; + { auto loader = XlaCallModuleLoader::Create( &context_, version, std::move(module_str), std::move(disabled_checks), @@ -187,10 +208,6 @@ class XlaCallModuleOp : public XlaOpKernel { } OP_REQUIRES_OK(ctx, loader_->ValidateDialect()); - if (!ctx->GetAttr("function_list", &function_list_).ok()) { - function_list_.clear(); - } - if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) { token_input_nodes_.clear(); op_has_token_input_output_ = false; From 3441b35bb02d0cfdb77a890dc272326aa6b735ea Mon Sep 17 00:00:00 2001 From: Matthias Kramm Date: Mon, 23 Oct 2023 07:32:48 -0700 Subject: [PATCH 006/246] Introduce tf.GeneratorDatasetRegion op. 
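Conceptually, the new regional op runs its `init` region exactly once, its `next` region any number of times, and its `finalize` region exactly once, discarding `finalize`'s results (see `getRegionInvocationBounds` and the `YieldOp` change below). A plain C++ analogue of that contract, with placeholder types standing in for the tensors that actually flow between regions:

```
#include <functional>
#include <optional>
#include <vector>

// Placeholders for the values flowing between regions in the sketch below.
using State = std::vector<float>;    // what `init` yields
using Element = std::vector<float>;  // one element produced by `next`

// Conceptual driver only: the real op is handled by the TF runtime and, as
// the op description below notes, is modeled in MLIR as if it eagerly
// filled a buffer with all results.
std::vector<Element> RunGeneratorDataset(
    const std::function<State()>& init,
    const std::function<std::optional<Element>(const State&)>& next,
    const std::function<void(const State&)>& finalize) {
  State state = init();  // invoked exactly once
  std::vector<Element> buffer;
  // `next` may run any number of times; exhaustion is modeled via nullopt.
  while (std::optional<Element> element = next(state)) {
    buffer.push_back(*element);
  }
  finalize(state);  // invoked exactly once; its results are discarded
  return buffer;
}
```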
PiperOrigin-RevId: 575813013 --- tensorflow/compiler/mlir/tensorflow/BUILD | 5 ++ .../compiler/mlir/tensorflow/ir/tf_ops.td | 53 ++++++++++- .../compiler/mlir/tensorflow/ir/tf_ops_a_m.cc | 89 +++++++++++++++++-- .../compiler/mlir/tensorflow/ir/tf_ops_n_z.cc | 28 +++--- .../mlir/tensorflow/tests/tf-ops.mlir | 40 ++++++++- .../compiler/mlir/tf2xla/transforms/BUILD | 7 +- .../transforms/legalization_op_config_test.cc | 11 +-- 7 files changed, 202 insertions(+), 31 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 78d87e81c45525..01f225c9abd595 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -353,11 +353,14 @@ cc_library( ":attribute_utils", ":convert_type", ":dynamic_shape_utils", + ":tensorflow_all_ops_inc_gen", ":tensorflow_attributes", ":tensorflow_op_interfaces", ":tensorflow_op_interfaces_inc_gen", + ":tensorflow_remaining_ops_inc_gen", ":tensorflow_side_effects", ":tensorflow_structs", + ":tensorflow_tfrt_ops_inc_gen", ":tensorflow_traits", ":tensorflow_types", ":tf_arith_ops_folder", @@ -369,6 +372,8 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_canonicalize_inc_gen", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:ControlFlowInterfaces", "@llvm-project//mlir:DerivedAttributeOpInterface", diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 5e4d546c4d47bc..ddedaf63ade266 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -312,7 +312,7 @@ def TF_YieldOp : TF_Op<"Yield", [Terminator, Pure, NativeOpTrait<"ReturnLike", [], "", "">, - ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp"]>, + ParentOneOf<["CaseRegionOp", "IfRegionOp", "WhileRegionOp", "GeneratorDatasetRegionOp"]>, DeclareOpInterfaceMethods, ]> { @@ -389,6 +389,57 @@ else_branch: A region that computes the outputs of the op if cond = false. let hasCanonicalizer = 1; } +def TF_GeneratorDatasetRegionOp : TF_Op<"GeneratorDatasetRegion", + [AttrSizedOperandSegments, + DeclareOpInterfaceMethods, + SingleBlockImplicitTerminator<"YieldOp">, + TF_GeneratorOpSideEffect, + ]> { + let summary = "Regional version of GeneratorDataset"; + + let description = [{ +Creates a dataset that invokes its 'next' region to generate elements. Conceptually, +within MLIR, we treat this op as if it fills a buffer with all the results right away, +and those results are then passed (through the variant tensor result) to +MakeIterator / IteratorGetNext. Note that the actual TF implementation differs: It +generates the next element just in time, during IteratorGetNext. + +init_extra_args: Additional arguments to pass to 'init'. +next_extra_args: Additional arguments to pass to 'next'. (Passed after the + normal arguments which are from the return values of 'init'.) +finalize_extra_args: Additional arguments to pass to 'finalize'. (Passed after + the normal arguments which are from the return values of 'init'.) 
+ }]; + + let arguments = (ins + Variadic:$init_func_other_args, + Variadic:$next_func_other_args, + Variadic:$finalize_func_other_args, + + ConfinedAttr]>:$output_types, + ConfinedAttr]>:$output_shapes, + DefaultValuedOptionalAttr:$metadata + ); + + let results = (outs + TF_VariantTensor:$handle + ); + + let regions = (region SizedRegion<1>:$init, + SizedRegion<1>:$next, + SizedRegion<1>:$finalize + ); + + TF_DerivedOperandTypeListAttr Tinit_func_args = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedOperandTypeListAttr Tnext_func_args = TF_DerivedOperandTypeListAttr<1>; + TF_DerivedOperandTypeListAttr Tfinalize_func_args = TF_DerivedOperandTypeListAttr<2>; +} + def TF_LegacyCallOp : TF_Op<"LegacyCall", [CallOpInterface, DeclareOpInterfaceMethods, Pure]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc index 4f21118dc24b71..14ffb82412c3d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.cc @@ -19,55 +19,60 @@ limitations under the License. #include #include #include +#include #include -#include #include -#include -#include #include #include #include #include +#include "absl/log/check.h" +#include "absl/strings/str_cat.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/iterator_range.h" -#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project 
#include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_device_helper.h" @@ -75,12 +80,14 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/core/framework/kernel_shape_util.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -2968,6 +2975,70 @@ StringRef FusedBatchNormV3Op::GetOptimalLayout(const RuntimeDevices& devices) { return ::mlir::TF::GetOptimalLayout(devices, this); } +//===----------------------------------------------------------------------===// +// GeneratorDatasetRegionOp +//===----------------------------------------------------------------------===// + +bool GeneratorDatasetRegionOp::areTypesCompatible(Type t1, Type t2) { + return true; // Don't enforce type checking across control-flow edges. +} + +void GeneratorDatasetRegionOp::getRegionInvocationBounds( + ArrayRef operands, + SmallVectorImpl& invocationBounds) { + // We invoke `init` once, `finalize` once, and `next` any number of times. + invocationBounds.emplace_back(InvocationBounds(1, 1)); // init + invocationBounds.emplace_back(InvocationBounds::getUnknown()); // next + invocationBounds.emplace_back(InvocationBounds(1, 1)); // finalize +} + +OperandRange GeneratorDatasetRegionOp::getEntrySuccessorOperands( + RegionBranchPoint point) { + auto end = this->getOperation()->operand_end(); + if (point.isParent()) { + // The op itself doesn't branch back to itself. + return ::mlir::OperandRange(end, end); + } else if (point.getRegionOrNull() == &getInit()) { + return getInitFuncOtherArgs(); + } else if (point.getRegionOrNull() == &getNext()) { + return getNextFuncOtherArgs(); + } else /* finalize region */ { + return getFinalizeFuncOtherArgs(); + } +} + +void GeneratorDatasetRegionOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl& regions) { + int n; + if (point.isParent()) { + // The op itself branches to `init` first. + regions.push_back( + RegionSuccessor(&getInit(), getInit().front().getArguments())); + } else if (point.getRegionOrNull() == &getInit()) { + // `init` branches to `next`, passing along the arguments given to `init`'s + // yield. Said arguments precede the "other args". + n = getInitFuncOtherArgs().size(); + regions.push_back(RegionSuccessor( + &getNext(), getNext().front().getArguments().drop_back(n))); + } else if (point.getRegionOrNull() == &getNext()) { + // `next` branches to itself, or to `finalize`, passing all arguments given + // to `next`s yield. + + // The number of values we're passing along. 
+ int num = getNext().front().getTerminator()->getNumOperands(); + + // The number of extra values from the parent ops that should go to `next` + // and `finalize`. + regions.push_back(RegionSuccessor( + &getNext(), getNext().front().getArguments().slice(0, num))); + regions.push_back(RegionSuccessor( + &getFinalize(), getFinalize().front().getArguments().slice(0, num))); + } else { + // `finalize` branches back to the op itself, not passing any arguments. + regions.push_back(RegionSuccessor()); + } +} + //===----------------------------------------------------------------------===// // GatherV2Op //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 4ac541799afd4f..01cbbb9a46967c 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -18,11 +18,12 @@ limitations under the License. #include #include +#include +#include +#include #include -#include #include #include -#include #include #include #include @@ -34,16 +35,15 @@ limitations under the License. #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/ADT/Twine.h" #include "llvm/ADT/iterator_range.h" #include "llvm/Support/Casting.h" #include "llvm/Support/FormatVariadic.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/MathExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project @@ -53,26 +53,28 @@ limitations under the License. 
#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project -#include "mlir/IR/DialectImplementation.h" // from @llvm-project #include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project -#include "mlir/IR/OpImplementation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Region.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project #include "mlir/IR/TypeRange.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project #include "mlir/Interfaces/ControlFlowInterfaces.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project #include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_arith_ops_folder.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_canonicalization_helper.h" @@ -80,14 +82,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_layout_helper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_tensor_helper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_side_effects.h" -#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/rewrite_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/util/tensor_format.h" namespace mlir { namespace TF { @@ -4389,6 +4389,12 @@ MutableOperandRange YieldOp::getMutableSuccessorOperands( this->getOperation(), 1, this->getOperation()->getOperands().size() - 1); } + } else if (auto regionOp = llvm::dyn_cast( + this->getOperation()->getParentOp())) { + if (®ionOp.getFinalize() == this->getOperation()->getParentRegion()) { + // `finalize`'s returns get discarded. 
+ return MutableOperandRange(this->getOperation(), 0, 0); + } } return MutableOperandRange(this->getOperation()); } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 22de9ee736426a..7605f0360625fe 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1185,7 +1185,7 @@ func.func @testInvalidIfOp(tensor, tensor<*xf32>) -> tensor<2xf32> { // Test invalid tf.Yield operation (parent should be IfRegion) func.func @testInvalidYieldOp(%arg0: f32) -> () { - // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion'}} + // expected-error @+1 {{'tf.Yield' op expects parent op to be one of 'tf.CaseRegion, tf.IfRegion, tf.WhileRegion, tf.GeneratorDatasetRegion'}} "tf.Yield"(%arg0) : (f32) -> () } @@ -5180,3 +5180,41 @@ func.func @test_xla_call_module_with_invalid_symbol() { "tf.XlaCallModule"() {Sout = [], device = "", dim_args_spec = [], function_list = [@undefined_function], module = "", platforms = [], version = 4 : i64} : () -> () func.return } + +// ----- + +func.func @init(%arg0: tensor<4xf32>) -> tensor<7xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<7xf32> + return %0 : tensor<7xf32> +} + +func.func @next(%arg0: tensor<7xf32>, %arg1: tensor<3xf32>) -> tensor<6xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<6xf32> + return %0 : tensor<6xf32> +} + +func.func @finalize(%arg0: tensor<6xf32>, %arg1: tensor<2xf32>) -> tensor<5xf32> { + %0 = builtin.unrealized_conversion_cast to tensor<5xf32> + return %0 : tensor<5xf32> +} + +// CHECK-LABEL: func @testGeneratorDataset +func.func @testGeneratorDataset(%arg0: tensor<4xf32>, + %arg1: tensor<3xf32>, + %arg2: tensor, + %arg3: tensor<2xf32>) -> tensor { + %0 = "tf.GeneratorDataset"(%arg0, %arg1, %arg2, %arg3) { + device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0", + finalize_func = @finalize, + init_func = @init, + next_func = @next, + operandSegmentSizes = array, + output_shapes = [#tf_type.shape<>], + output_types = [!tf_type.string], + metadata = ""} : ( + tensor<4xf32>, + tensor<3xf32>, + tensor, + tensor<2xf32>) -> tensor + return %0 : tensor +} diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index b4f5d00b7a9e37..ed0429ad242c94 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -1,10 +1,10 @@ # Description: # TF2XLA Bridge transforms -load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static") package( @@ -487,13 +487,12 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/status", "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", "@local_tsl//tsl/lib/core:status_test_util", - "@local_tsl//tsl/platform:errors", "@local_tsl//tsl/platform:status", 
"@local_tsl//tsl/platform:statusor", ], diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 24f7711539a5b4..4063a85477f8b2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -20,21 +20,22 @@ limitations under the License. #include #include -#include #include #include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/register_common_dialects.h" -#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/core/tpu/tpu_defs.h" #include "tsl/lib/core/status_test_util.h" -#include "tsl/platform/errors.h" #include "tsl/platform/status.h" #include "tsl/platform/statusor.h" @@ -131,7 +132,7 @@ TEST_F(LegalizationOpConfigTest, CountLoweringsSet) { // a new op, we should expect these to change too. EXPECT_EQ(mlir_lowering_count, 67); EXPECT_EQ(tf2xla_fallback_count, 315); - EXPECT_EQ(non_categorized_count, 420); + EXPECT_EQ(non_categorized_count, 421); } // Just a counter test to see which ops have duplicate lowerings. This isn't a From 1b57ad592d0afc7e51e56a4b32f60bff20e653c7 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 07:34:00 -0700 Subject: [PATCH 007/246] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8. PiperOrigin-RevId: 575813261 --- third_party/tf_runtime/workspace.bzl | 4 ++-- third_party/xla/third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index af9d9dcce05374..a5384b613a5232 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "f6b5570b2e04978d6362a0f307982c56fb0e01cd" - TFRT_SHA256 = "d16308e600db822a7b2e21921febc5d34cc3fcf0197739e2c258b32083cc5da4" + TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" + TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tf_runtime/workspace.bzl index af9d9dcce05374..a5384b613a5232 100644 --- a/third_party/xla/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "f6b5570b2e04978d6362a0f307982c56fb0e01cd" - TFRT_SHA256 = "d16308e600db822a7b2e21921febc5d34cc3fcf0197739e2c258b32083cc5da4" + TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" + TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index af9d9dcce05374..a5384b613a5232 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "f6b5570b2e04978d6362a0f307982c56fb0e01cd" - TFRT_SHA256 = "d16308e600db822a7b2e21921febc5d34cc3fcf0197739e2c258b32083cc5da4" + TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" + TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" tf_http_archive( name = "tf_runtime", From ba095590a538bd9daed0a7186531f77de7ef0443 Mon Sep 17 00:00:00 2001 From: Ilia Sergachev Date: Mon, 23 Oct 2023 07:34:45 -0700 Subject: [PATCH 008/246] [XLA:GPU] Avoid scheduling already scheduled modules. PiperOrigin-RevId: 575813414 --- third_party/xla/xla/service/gpu/BUILD | 20 ++++++++ .../xla/xla/service/gpu/gpu_hlo_schedule.cc | 21 +++++++++ .../xla/service/gpu/gpu_hlo_schedule_test.cc | 46 +++++++++++++++++++ 3 files changed, 87 insertions(+) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8e44ef4fdbc6a6..d3be0ea927744f 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3173,8 +3173,13 @@ cc_library( deps = [ ":backend_configs_cc", ":cublas_cudnn", + "//xla:shape_util", + "//xla:status", + "//xla:statusor", + "//xla:util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:buffer_value", "//xla/service:hlo_memory_scheduler", "//xla/service:hlo_pass_pipeline", "//xla/service:latency_hiding_scheduler", @@ -3182,6 +3187,11 @@ cc_library( "//xla/service:profile_guided_latency_estimator", "//xla/service/gpu/model:analytical_latency_estimator", "//xla/stream_executor:device_description", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", @@ -3199,13 +3209,23 @@ xla_cc_test( tags = tf_cuda_tests_tags(), deps = [ ":gpu_hlo_schedule", + "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/utils:hlo_query", + "//xla/service:backend", "//xla/service:gpu_plugin", + "//xla/service:hlo_module_config", + "//xla/service:hlo_ordering", "//xla/stream_executor:device_description", + "//xla/tests:filecheck", "//xla/tests:hlo_test_base", "//xla/tests:test_utils", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/log", + "@com_google_googletest//:gtest", + "@local_tsl//tsl/platform:status", + "@local_tsl//tsl/platform:statusor", "@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc", ], ) diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc index dd01ed3bfdce96..8d8a7ca04d85f1 100644 --- 
a/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule.cc @@ -15,6 +15,8 @@ limitations under the License. #include "xla/service/gpu/gpu_hlo_schedule.h" +#include +#include #include #include #include @@ -22,14 +24,25 @@ limitations under the License. #include #include +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/check.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_split.h" #include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_input_output_alias_config.h" +#include "xla/hlo/ir/hlo_instruction.h" #include "xla/hlo/ir/hlo_instructions.h" +#include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/buffer_value.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" #include "xla/service/gpu/model/analytical_latency_estimator.h" @@ -38,7 +51,12 @@ limitations under the License. #include "xla/service/latency_hiding_scheduler.h" #include "xla/service/p2p_schedule_preparation.h" #include "xla/service/profile_guided_latency_estimator.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status.h" +#include "xla/statusor.h" #include "xla/stream_executor/device_description.h" +#include "xla/util.h" #include "tsl/platform/env.h" #include "tsl/platform/errors.h" #include "tsl/platform/protobuf.h" @@ -600,6 +618,9 @@ int64_t GetSizeOfShape(const Shape& shape, int pointer_size) { Status ScheduleGpuModule(HloModule* module, int64_t pointer_size, int64_t memory_limit, const se::DeviceDescription& gpu_device_info) { + if (module->has_schedule()) { + return OkStatus(); + } HloPassPipeline prepare_pipeline("p2p-schedule-preparation"); prepare_pipeline.AddPass(); TF_RETURN_IF_ERROR(prepare_pipeline.Run(module).status()); diff --git a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc index 8fc62dbebc5fd0..e952c0c9d8b07e 100644 --- a/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_hlo_schedule_test.cc @@ -16,20 +16,34 @@ limitations under the License. 
#include "xla/service/gpu/gpu_hlo_schedule.h" #include +#include +#include #include #include #include #include #include +#include +#include "absl/algorithm/container.h" +#include "absl/log/log.h" #include "xla/hlo/ir/hlo_computation.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/utils/hlo_query.h" +#include "xla/service/backend.h" +#include "xla/service/hlo_module_config.h" +#include "xla/service/hlo_ordering.h" +#include "xla/shape.h" +#include "xla/shape_util.h" #include "xla/stream_executor/device_description.h" +#include "xla/tests/filecheck.h" #include "xla/tests/hlo_test_base.h" #include "xla/tests/test_utils.h" +#include "tsl/platform/status.h" +#include "tsl/platform/statusor.h" #include "tsl/profiler/protobuf/profiled_instructions.pb.h" namespace xla { @@ -891,6 +905,38 @@ while_body { get_index("recv.1", while_body)); } +TEST_F(GpuHloScheduleTest, SkipAlreadyScheduled) { + auto module = ParseAndReturnVerifiedModule(R"( +HloModule m, is_scheduled=true + +fused_computation { + param_0 = f32[1024,1024]{1,0} parameter(0) + ROOT exponential.1 = f32[1024,1024]{1,0} exponential(param_0) +} + +fused_computation.1 { + param_0.1 = f32[1024,1024]{1,0} parameter(0) + ROOT negate.1 = f32[1024,1024]{1,0} negate(param_0.1) +} + +ENTRY e { + p = f32[1024,1024]{1,0} parameter(0) + wrapped_negate = f32[1024,1024]{1,0} fusion(p), kind=kLoop, calls=fused_computation.1 + wrapped_exponential = f32[1024,1024]{1,0} fusion(p), kind=kLoop, calls=fused_computation + ROOT t = (f32[1024,1024]{1,0}, f32[1024,1024]{1,0}) tuple(wrapped_exponential, wrapped_negate) +})") + .value(); + TF_CHECK_OK(ScheduleGpuModule( + module.get(), /*pointer_size=*/8, + /*memory_limit=*/1024 * 1024 * 1024, + backend().default_stream_executor()->GetDeviceDescription())); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( +// CHECK: ENTRY +// CHECK: wrapped_negate = f32[1024,1024]{1,0} +// CHECK: wrapped_exponential = f32[1024,1024]{1,0} +)")); +} + class GpuHloScheduleParameterizedTest : public GpuHloScheduleTest, public ::testing::WithParamInterface {}; From 68e88d2a0f72ad623e4155a3d9683ff2bc5f6e3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tam=C3=A1s=20Danyluk?= Date: Mon, 23 Oct 2023 07:52:53 -0700 Subject: [PATCH 009/246] [XLA:GPU] Handle NaNs correctly in minimum and maximum Fixed that IrEmitterTriton "swallowed" NaNs in minimum and maximum. Fixed that ElementalIrEmitter "swallowed" NaNs in the case of minimum(x, NaN) on GPU. That was likely caused by an llvm error. 
PiperOrigin-RevId: 575817645 --- third_party/xla/xla/service/BUILD | 2 + .../xla/service/elemental_ir_emitter_test.cc | 142 +++++++++++++ .../xla/xla/service/gpu/ir_emitter_triton.cc | 38 +++- .../xla/service/gpu/ir_emitter_triton_test.cc | 201 ++++++++++++++++++ .../xla/xla/service/llvm_ir/llvm_util.cc | 31 ++- 5 files changed, 404 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/BUILD b/third_party/xla/xla/service/BUILD index 3e69749aa7fecc..5ffe582bedda79 100644 --- a/third_party/xla/xla/service/BUILD +++ b/third_party/xla/xla/service/BUILD @@ -5411,6 +5411,7 @@ xla_test( ], deps = [ ":hlo_parser", + "//xla:error_spec", "//xla:execution_options_util", "//xla:status_macros", "//xla:test", @@ -5418,6 +5419,7 @@ xla_test( "//xla/tests:hlo_test_base", "//xla/tests:test_macros_header", "//xla/tests:xla_internal_test_main", + "@com_google_absl//absl/strings", ], ) diff --git a/third_party/xla/xla/service/elemental_ir_emitter_test.cc b/third_party/xla/xla/service/elemental_ir_emitter_test.cc index 9ca6f0a918ba6b..9823d9101eb5f4 100644 --- a/third_party/xla/xla/service/elemental_ir_emitter_test.cc +++ b/third_party/xla/xla/service/elemental_ir_emitter_test.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/string_view.h" +#include "xla/error_spec.h" #include "xla/execution_options_util.h" #include "xla/service/hlo_parser.h" #include "xla/status_macros.h" @@ -48,6 +50,18 @@ class ElementalIrEmitterExecutionTest : public HloTestBase { } }; +class ElementalIrEmitterExecutionTestWithoutFastMinMax + : public ElementalIrEmitterExecutionTest { + protected: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = + ElementalIrEmitterExecutionTest::GetDebugOptionsForTest(); + debug_options.set_xla_cpu_enable_fast_min_max(false); + debug_options.set_xla_gpu_enable_fast_min_max(false); + return debug_options; + } +}; + XLA_TEST_F(ElementalIrEmitterExecutionTest, DotFusion) { const std::string hlo_text = R"( HloModule FusedDot @@ -669,5 +683,133 @@ XLA_TEST_F(ElementalIrEmitterExecutionTest, IotaF8E5FNUZ) { RunTest(hlo_text, {}); } +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MinimumHandlesNaNsOnTheLeft) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + ROOT min = f32[5,5] minimum(nans, neg1s) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MinimumHandlesNaNsOnTheRight) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + ROOT min = f32[5,5] minimum(neg1s, nans) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MaximumHandlesNaNsOnTheLeft) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + ROOT max = f32[5,5] maximum(nans, neg1s) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + 
+XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MaximumHandlesNaNsOnTheRight) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + ROOT max = f32[5,5] maximum(neg1s, nans) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MinimumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + ROOT min = f32[5,5] minimum(zeros, ones) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MinimumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + ROOT min = f32[5,5] minimum(ones, zeros) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MaximumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + ROOT max = f32[5,5] maximum(ones, zeros) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +XLA_TEST_F(ElementalIrEmitterExecutionTestWithoutFastMinMax, + MaximumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + ROOT max = f32[5,5] maximum(zeros, ones) +})"; + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + } // namespace } // namespace xla diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc index ca90128c63a7d5..fb9fcf4bddc956 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton.cc @@ -316,13 +316,43 @@ Value Compare(ImplicitLocOpBuilder& b, ValueRange values, } Value Maximum(ImplicitLocOpBuilder& b, ValueRange values) { - auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::GT); - return b.create(cmp, values[0], values[1]); + // ma::MaximumFOp seems to think that max(NaN, x) = x, so we don't use that. + // + // logic: isNaN(lhs) || (!isNan(rhs) && lhs > rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This also works, but we wanted to make it similar to minimum. + // logic: isNaN(lhs) || lhs > rhs ? 
lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_greater = + Compare(b, values, mlir::mhlo::ComparisonDirection::GT); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_greater)), + values[0], values[1]); } Value Minimum(ImplicitLocOpBuilder& b, ValueRange values) { - auto cmp = Compare(b, values, mlir::mhlo::ComparisonDirection::LT); - return b.create(cmp, values[0], values[1]); + // ma::MinimumFOp seems to think that min(NaN, x) = x, so we don't use that. + // + // logic: isNaN(lhs) || (!isNan(rhs) && lhs < rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This should also work, but the tests show that it doesn't work for + // minimum(x, NaN): + // logic: isNaN(lhs) || lhs < rhs ? lhs : rhs + Value lhs_is_nan = + Compare(b, {values[0], values[0]}, mlir::mhlo::ComparisonDirection::NE); + Value rhs_is_not_nan = + Compare(b, {values[1], values[1]}, mlir::mhlo::ComparisonDirection::EQ); + Value lhs_is_less = Compare(b, values, mlir::mhlo::ComparisonDirection::LT); + return b.create( + b.create(lhs_is_nan, + b.create(rhs_is_not_nan, lhs_is_less)), + values[0], values[1]); } // TODO(b/269489810): Contribute nicer builders to Triton, so we don't need to diff --git a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc index 63ba010975f8ff..bdc15e99556df0 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_triton_test.cc @@ -995,6 +995,15 @@ class TritonGemmLevel2Test : public TritonGemmTest { } }; +class TritonGemmLevel2TestAny : public TritonGemmLevel2Test { + public: + DebugOptions GetDebugOptionsForTest() override { + DebugOptions debug_options = TritonGemmLevel2Test::GetDebugOptionsForTest(); + debug_options.set_xla_gpu_triton_gemm_any(true); + return debug_options; + } +}; + TEST_F(TritonGemmLevel2Test, BinaryOperationWithSmallInputsIsFused) { const std::string kHloText = R"( HloModule m @@ -1271,6 +1280,198 @@ ENTRY e { EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/2e-3, /*arel=*/2e-3})); } +TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheLeft) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + min = f32[5,5] minimum(nans, neg1s) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MinimumHandlesNaNsOnTheRight) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + min = f32[5,5] minimum(neg1s, nans) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheLeft) { + 
constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + max = f32[5,5] maximum(nans, neg1s) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MaximumHandlesNaNsOnTheRight) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + neg1 = f32[] constant(-1) + neg1s = f32[5,5] broadcast(neg1), dimensions={} + nans = f32[5,5] sqrt(neg1s) + max = f32[5,5] maximum(neg1s, nans) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MinimumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + min = f32[5,5] minimum(zeros, ones) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MinimumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + min = f32[5,5] minimum(ones, zeros) + ROOT _ = f32[5,5] dot(p0, min), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MaximumReturnsLHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + max = f32[5,5] maximum(ones, zeros) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + +TEST_F(TritonGemmLevel2TestAny, MaximumReturnsRHS) { + constexpr absl::string_view kHloText = R"( +HloModule t + +ENTRY e { + p0 = f32[5,5] parameter(0) + zero = f32[] constant(0) + zeros = f32[5,5] broadcast(zero), dimensions={} + one = f32[] constant(1) + ones = f32[5,5] broadcast(one), dimensions={} + max = f32[5,5] maximum(zeros, ones) + ROOT _ = f32[5,5] dot(p0, max), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +})"; + + MatchOptimizedHlo(kHloText, R"( +; CHECK: fusion +; CHECK-SAME: 
kind=kCustom +; CHECK-SAME: block_m +)"); + + EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-3, + /*arel=*/1e-3})); +} + TEST_F(TritonGemmTest, SineOutputIsNotFused) { const std::string kHloText = R"( HloModule m diff --git a/third_party/xla/xla/service/llvm_ir/llvm_util.cc b/third_party/xla/xla/service/llvm_ir/llvm_util.cc index 6fddf8ec7a8bbb..3dbf2bf48850da 100644 --- a/third_party/xla/xla/service/llvm_ir/llvm_util.cc +++ b/third_party/xla/xla/service/llvm_ir/llvm_util.cc @@ -143,10 +143,19 @@ llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = b->CreateFCmpUGE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value, name.data()); } else { - auto cmp_ge = b->CreateFCmpOGE(lhs_value, rhs_value); + // logic: isNaN(lhs) || (!isNan(rhs) && lhs > rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This also works, but we wanted to make it similar to minimum. + // logic: isNaN(lhs) || lhs > rhs ? lhs : rhs + // + // b->CreateMaximum() doesn't work on GPU before SM80. auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = b->CreateOr(cmp_ge, lhs_is_nan); - return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data()); + auto rhs_is_not_nan = b->CreateFCmpOEQ(rhs_value, rhs_value); + auto lhs_is_greater = b->CreateFCmpOGT(lhs_value, rhs_value); + return b->CreateSelect( + b->CreateOr(lhs_is_nan, b->CreateAnd(rhs_is_not_nan, lhs_is_greater)), + lhs_value, rhs_value, name.data()); } } @@ -157,10 +166,20 @@ llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value, auto cmp = b->CreateFCmpULE(lhs_value, rhs_value); return b->CreateSelect(cmp, lhs_value, rhs_value, name.data()); } else { - auto cmp_le = b->CreateFCmpOLE(lhs_value, rhs_value); + // logic: isNaN(lhs) || (!isNan(rhs) && lhs < rhs) ? lhs : rhs + // See also: IEEE Std 754-2008 5.11. + // + // This should also work, but the tests show that it doesn't work for + // minimum(x, NaN) on GPU: + // logic: isNaN(lhs) || lhs < rhs ? lhs : rhs + // + // b->CreateMaximum() doesn't work on GPU before SM80. auto lhs_is_nan = b->CreateFCmpUNE(lhs_value, lhs_value); - auto sel_lhs = b->CreateOr(cmp_le, lhs_is_nan); - return b->CreateSelect(sel_lhs, lhs_value, rhs_value, name.data()); + auto rhs_is_not_nan = b->CreateFCmpOEQ(rhs_value, rhs_value); + auto lhs_is_less = b->CreateFCmpOLT(lhs_value, rhs_value); + return b->CreateSelect( + b->CreateOr(lhs_is_nan, b->CreateAnd(rhs_is_not_nan, lhs_is_less)), + lhs_value, rhs_value, name.data()); } } From c3a5443f5b3afd94fa2d1aa1b7da894756681f3d Mon Sep 17 00:00:00 2001 From: Fergus Henderson Date: Mon, 23 Oct 2023 08:25:36 -0700 Subject: [PATCH 010/246] Add forwarding shim for the C++ XNNPack plugin. 
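The implementation file moves to
tensorflow/lite/core/acceleration/configuration/, and the old package keeps a
target of the same name that only forwards to the new location, so existing
dependents keep building unchanged. A sketch of an unaffected client target
(the target below is hypothetical, for illustration only):

  cc_library(
      name = "my_delegate_client",  # hypothetical client target
      srcs = ["my_delegate_client.cc"],
      deps = [
          # The old label still resolves, via the forwarding shim:
          "//tensorflow/lite/acceleration/configuration:xnnpack_plugin",
      ],
  )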
PiperOrigin-RevId: 575825788 --- .../lite/acceleration/configuration/BUILD | 46 +++++++--- .../configuration/xnnpack_plugin_test.cc | 18 +--- .../core/acceleration/configuration/BUILD | 36 +++++++- .../configuration/xnnpack_plugin.cc | 0 .../configuration/xnnpack_plugin_test.cc | 91 +++++++++++++++++++ 5 files changed, 165 insertions(+), 26 deletions(-) rename tensorflow/lite/{ => core}/acceleration/configuration/xnnpack_plugin.cc (100%) create mode 100644 tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc diff --git a/tensorflow/lite/acceleration/configuration/BUILD b/tensorflow/lite/acceleration/configuration/BUILD index b005381b7dbbd4..4f1b2fe568cb97 100644 --- a/tensorflow/lite/acceleration/configuration/BUILD +++ b/tensorflow/lite/acceleration/configuration/BUILD @@ -17,7 +17,7 @@ load("@flatbuffers//:build_defs.bzl", "DEFAULT_FLATC_ARGS", "flatbuffer_android_ load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:build_def.bzl", "tflite_copts", "tflite_copts_warnings") load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite") -load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite") +load("//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite", "cc_test_with_tflite") # copybara:uncomment load("//tools/build_defs/proto/cpp:cc_proto_library.bzl", "cc_proto_library") load(":build_defs.bzl", "flatbuffer_schema_compat_test") @@ -324,29 +324,53 @@ cc_test( ], ) -cc_library( +cc_library_with_tflite( name = "xnnpack_plugin", - srcs = ["xnnpack_plugin.cc"], compatible_with = get_compatible_with_portable(), + visibility = ["//visibility:public"], + deps = ["//tensorflow/lite/core/acceleration/configuration:xnnpack_plugin"], +) + +cc_test_with_tflite( + name = "xnnpack_plugin_with_tflite_test", + srcs = ["xnnpack_plugin_test.cc"], + # The variant of this test that links against TF Lite in Play services + # isn't portable to iOS / Mac or Android, because it relies on a separate + # shared library that isn't included in the executable, and the testing + # infrastructure for iOS and Android doesn't propagate data dependencies + # to the test device. So we disable this test on those devices. + # TODO(b/306161304): ideally we ought to apply these tags only to the + # variant for TF Lite in Play services. In the mean time, we apply those + # tags to the whole test, but also duplicate the test below using cc_test + # without the tags. + tags = [ + "no_mac", + "tflite_not_portable_android", + "tflite_not_portable_ios", + ], + tflite_deps = [ + ":xnnpack_plugin", + "//tensorflow/lite:test_util", + "//tensorflow/lite/acceleration/configuration:delegate_registry", + ], deps = [ ":configuration_fbs", - "//tensorflow/lite:minimal_logging", - "//tensorflow/lite/core/acceleration/configuration:delegate_registry", - "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", - "@com_google_absl//absl/base:log_severity", - "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + "@pthreadpool", ], - alwayslink = 1, # For registration to always run. ) +# This duplicates xnnnpack_plugin_with_tflite_test above, but without the tags, +# to ensure that this test does get run on iOS and Android. 
cc_test( name = "xnnpack_plugin_test", srcs = ["xnnpack_plugin_test.cc"], deps = [ ":configuration_fbs", ":xnnpack_plugin", - "//tensorflow/lite/core/acceleration/configuration:delegate_registry", - "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "//tensorflow/lite:test_util", + "//tensorflow/lite/acceleration/configuration:delegate_registry", "@com_google_googletest//:gtest_main", "@flatbuffers//:runtime_cc", "@pthreadpool", diff --git a/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc index 2aa1d95a44f10d..3138e7f2c584be 100644 --- a/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc +++ b/tensorflow/lite/acceleration/configuration/xnnpack_plugin_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include "flatbuffers/flatbuffers.h" // from @flatbuffers #include "pthreadpool.h" // from @pthreadpool #include "tensorflow/lite/acceleration/configuration/configuration_generated.h" -#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" -#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" +#include "tensorflow/lite/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/test_util.h" namespace tflite { -class XnnpackPluginTest : public testing::Test { +class XnnpackPluginTest : public tflite::testing::Test { public: static constexpr int kNumThreadsForTest = 7; static constexpr tflite::XNNPackFlags kFlagsForTest = @@ -70,22 +70,14 @@ class XnnpackPluginTest : public testing::Test { constexpr int XnnpackPluginTest::kNumThreadsForTest; TEST_F(XnnpackPluginTest, CanCreateAndDestroyDelegate) { - delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + delegates::TfLiteOpaqueDelegatePtr delegate = delegate_plugin_->Create(); EXPECT_NE(delegate, nullptr); } TEST_F(XnnpackPluginTest, CanGetDelegateErrno) { - delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + delegates::TfLiteOpaqueDelegatePtr delegate = delegate_plugin_->Create(); int error_number = delegate_plugin_->GetDelegateErrno(delegate.get()); EXPECT_EQ(error_number, 0); } -TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) { - delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); - pthreadpool_t threadpool = static_cast( - TfLiteXNNPackDelegateGetThreadPool(delegate.get())); - int thread_count = pthreadpool_get_threads_count(threadpool); - EXPECT_EQ(thread_count, kNumThreadsForTest); -} - } // namespace tflite diff --git a/tensorflow/lite/core/acceleration/configuration/BUILD b/tensorflow/lite/core/acceleration/configuration/BUILD index a1c1b2d45a6313..1f461966c8af37 100644 --- a/tensorflow/lite/core/acceleration/configuration/BUILD +++ b/tensorflow/lite/core/acceleration/configuration/BUILD @@ -1,7 +1,7 @@ -load("//tensorflow/lite/core:special_rules.bzl", "delegate_registry_visibility_allowlist") -load("//tensorflow/lite/core/c:special_rules.bzl", "experimental_acceleration_api_allowlist") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/lite:special_rules.bzl", "nnapi_plugin_impl_visibility_allowlist") +load("//tensorflow/lite/core:special_rules.bzl", "delegate_registry_visibility_allowlist") +load("//tensorflow/lite/core/c:special_rules.bzl", "experimental_acceleration_api_allowlist") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -88,3 +88,35 @@ cc_test( "@com_google_googletest//:gtest_main", ], ) + +cc_library( + name = "xnnpack_plugin", 
+ srcs = ["xnnpack_plugin.cc"], + compatible_with = get_compatible_with_portable(), + visibility = [ + "//tensorflow/lite:__subpackages__", + ], + deps = [ + "//tensorflow/lite:minimal_logging", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/core/acceleration/configuration:delegate_registry", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "@com_google_absl//absl/base:log_severity", + "@com_google_absl//absl/memory", + ], + alwayslink = 1, # For registration to always run. +) + +cc_test( + name = "xnnpack_plugin_test", + srcs = ["xnnpack_plugin_test.cc"], + deps = [ + ":xnnpack_plugin", + "//tensorflow/lite/acceleration/configuration:configuration_fbs", + "//tensorflow/lite/core/acceleration/configuration:delegate_registry", + "//tensorflow/lite/delegates/xnnpack:xnnpack_delegate", + "@com_google_googletest//:gtest_main", + "@flatbuffers//:runtime_cc", + "@pthreadpool", + ], +) diff --git a/tensorflow/lite/acceleration/configuration/xnnpack_plugin.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc similarity index 100% rename from tensorflow/lite/acceleration/configuration/xnnpack_plugin.cc rename to tensorflow/lite/core/acceleration/configuration/xnnpack_plugin.cc diff --git a/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc new file mode 100644 index 00000000000000..2aa1d95a44f10d --- /dev/null +++ b/tensorflow/lite/core/acceleration/configuration/xnnpack_plugin_test.cc @@ -0,0 +1,91 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Some very simple unit tests of the (C++) XNNPack Delegate Plugin. + +#include +#include +#include "flatbuffers/flatbuffers.h" // from @flatbuffers +#include "pthreadpool.h" // from @pthreadpool +#include "tensorflow/lite/acceleration/configuration/configuration_generated.h" +#include "tensorflow/lite/core/acceleration/configuration/delegate_registry.h" +#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h" + +namespace tflite { + +class XnnpackPluginTest : public testing::Test { + public: + static constexpr int kNumThreadsForTest = 7; + static constexpr tflite::XNNPackFlags kFlagsForTest = + tflite::XNNPackFlags::XNNPackFlags_TFLITE_XNNPACK_DELEGATE_FLAG_QS8_QU8; + void SetUp() override { + // Construct a FlatBuffer that contains + // TFLiteSettings { + // delegate: Delegate.XNNPACK, + // XNNPackSettings { num_threads: kNumThreadsForTest + // flags: TFLITE_XNNPACK_DELEGATE_FLAG_QS8 | + // TFLITE_XNNPACK_DELEGATE_FLAG_QU8 + // } + // }. 
+ XNNPackSettingsBuilder xnnpack_settings_builder(flatbuffer_builder_); + xnnpack_settings_builder.add_num_threads(kNumThreadsForTest); + xnnpack_settings_builder.add_flags(kFlagsForTest); + flatbuffers::Offset xnnpack_settings = + xnnpack_settings_builder.Finish(); + TFLiteSettingsBuilder tflite_settings_builder(flatbuffer_builder_); + tflite_settings_builder.add_xnnpack_settings(xnnpack_settings); + tflite_settings_builder.add_delegate(Delegate_XNNPACK); + flatbuffers::Offset tflite_settings = + tflite_settings_builder.Finish(); + flatbuffer_builder_.Finish(tflite_settings); + tflite_settings_ = flatbuffers::GetRoot( + flatbuffer_builder_.GetBufferPointer()); + // Create an XNNPack delegate plugin using the settings from the flatbuffer. + delegate_plugin_ = delegates::DelegatePluginRegistry::CreateByName( + "XNNPackPlugin", *tflite_settings_); + ASSERT_NE(delegate_plugin_, nullptr); + } + void TearDown() override { delegate_plugin_.reset(); } + ~XnnpackPluginTest() override {} + + protected: + // settings_ points into storage owned by flatbuffer_builder_. + flatbuffers::FlatBufferBuilder flatbuffer_builder_; + const TFLiteSettings *tflite_settings_; + std::unique_ptr delegate_plugin_; +}; + +constexpr int XnnpackPluginTest::kNumThreadsForTest; + +TEST_F(XnnpackPluginTest, CanCreateAndDestroyDelegate) { + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + EXPECT_NE(delegate, nullptr); +} + +TEST_F(XnnpackPluginTest, CanGetDelegateErrno) { + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + int error_number = delegate_plugin_->GetDelegateErrno(delegate.get()); + EXPECT_EQ(error_number, 0); +} + +TEST_F(XnnpackPluginTest, SetsCorrectThreadCount) { + delegates::TfLiteDelegatePtr delegate = delegate_plugin_->Create(); + pthreadpool_t threadpool = static_cast( + TfLiteXNNPackDelegateGetThreadPool(delegate.get())); + int thread_count = pthreadpool_get_threads_count(threadpool); + EXPECT_EQ(thread_count, kNumThreadsForTest); +} + +} // namespace tflite From 729644ec7665b3cc5b1d86bb2eb464583a25189a Mon Sep 17 00:00:00 2001 From: Arturo Schmidt Date: Mon, 23 Oct 2023 09:42:50 -0700 Subject: [PATCH 011/246] Add boilerplate code to add verify clustering passes. 
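The pass body added here is still a stub (runOnOperation is empty); the actual
verification logic lands separately. A hedged sketch of how the pass is
expected to be wired into a pass manager (only the factory function is part of
this change; the surrounding setup is illustrative):

  // Assumes an MLIRContext `context` and a ModuleOp `module` already exist.
  mlir::PassManager pm(&context);
  pm.addNestedPass<mlir::func::FuncOp>(
      tensorflow::tf2xla::internal::CreateVerifyClusteringPass());
  if (mlir::failed(pm.run(module))) {
    // Clustering invariants were violated.
  }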
PiperOrigin-RevId: 575845487 --- .../mlir/tf2xla/internal/passes/BUILD | 50 +++++++++++++++++++ .../internal/passes/clustering_passes.h | 35 +++++++++++++ .../internal/passes/clustering_passes.td | 26 ++++++++++ .../internal/passes/verify_clustering_pass.cc | 44 ++++++++++++++++ 4 files changed, 155 insertions(+) create mode 100644 tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD create mode 100644 tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h create mode 100644 tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td create mode 100644 tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD new file mode 100644 index 00000000000000..6c5dfba52090f0 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -0,0 +1,50 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__"], +) + +cc_library( + name = "clustering_passes", + srcs = [ + "clustering_passes.h", + ], + textual_hdrs = [ + "clustering_passes.h.inc", + ], + deps = [ + ":clustering_passes_inc_gen", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:framework", + "//tensorflow/core/transforms/toposort:Pass", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + ], +) + +gentbl_cc_library( + name = "clustering_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=TensorFlow", + ], + "clustering_passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "clustering_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", + ], +) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h new file mode 100644 index 00000000000000..09643d47cfff4f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h @@ -0,0 +1,35 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +// Verifies that all MLIR Ops have the expected attributes. +std::unique_ptr> +CreateVerifyClusteringPass(); + +#define GEN_PASS_DECL_VERIFYCLUSTERINGPASS + +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_INTERNAL_PASSES_CLUSTERING_PASSES_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td new file mode 100644 index 00000000000000..c1431369c6e0f3 --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.td @@ -0,0 +1,26 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +include "mlir/Pass/PassBase.td" + +def VerifyClusteringPass : Pass<"verify-clustering-pass", "mlir::func::FuncOp"> { + + let summary = "Verify that the Bridge output is correct and errors if verification fails."; + + let description = [{ + Verifies whether clustering has resulted in the expected invariants. These + include verifying that clusters have been created and have been outside + compiled, the result is device agnostic and in TF functional dialect & + that the device attribute exists. + }]; + + let constructor = "tensorflow::tf2xla::internal::CreateVerifyClusteringPass()"; +} diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc new file mode 100644 index 00000000000000..496c9fd2511e1f --- /dev/null +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass.cc @@ -0,0 +1,44 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace tensorflow { +namespace tf2xla { +namespace internal { + +namespace { + +#define GEN_PASS_DEF_VERIFYCLUSTERINGPASS +#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h.inc" + +class VerifyClusteringPass + : public impl::VerifyClusteringPassPassBase { + public: + void runOnOperation() override; +}; + +void VerifyClusteringPass::runOnOperation() {} +} // namespace + +std::unique_ptr> CreateVerifyClusteringPass() { + return std::make_unique(); +} +} // namespace internal +} // namespace tf2xla +} // namespace tensorflow From 319d697f87b4c5c6a9fc7cb3740c896f82a2a55a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 09:51:45 -0700 Subject: [PATCH 012/246] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/18a68a86884cc3b033586e725cb65557420fe3ea. PiperOrigin-RevId: 575847883 --- third_party/tf_runtime/workspace.bzl | 4 ++-- third_party/xla/third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index a5384b613a5232..0b3107e4c868ad 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" - TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" + TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" + TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tf_runtime/workspace.bzl index a5384b613a5232..0b3107e4c868ad 100644 --- a/third_party/xla/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" - TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" + TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" + TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index a5384b613a5232..0b3107e4c868ad 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "4e2f93c2c98c3ee32fd62ef7dbd5cc75c80011b8" - TFRT_SHA256 = "78165bdbbc70ed583d889ce10e8aefe915774e09ae4267b6154ca788fb94d758" + TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" + TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" tf_http_archive( name = "tf_runtime", From 4f36d36a1c14565b6221c8c90bb3e46ee26db7d5 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Oct 2023 09:58:55 -0700 Subject: [PATCH 013/246] Touch up the requirements' updater README. PiperOrigin-RevId: 575850088 --- ci/official/requirements_updater/README.md | 145 ++++++++++++++------- 1 file changed, 99 insertions(+), 46 deletions(-) diff --git a/ci/official/requirements_updater/README.md b/ci/official/requirements_updater/README.md index 292cb072a367aa..ad2e350dec8dc6 100644 --- a/ci/official/requirements_updater/README.md +++ b/ci/official/requirements_updater/README.md @@ -1,75 +1,128 @@ -### Hermetic Python +# Hermetic Python -Hermetic Python allows us not to rely on system-installed python and -system-installed python packages, instead we register our own python toolchain. +Hermetic Python allows not to rely on system-installed Python, and +system-installed Python packages. \ +Instead, an independent Python toolchain is registered, ensuring the right +dependencies are always used. \ See https://github.com/bazelbuild/rules_python/ for more details. -#### Hermetic Python toolchain details +### Specifying the Python version -By default, Python 3.9 is used. +Note: Only a number of minor Python versions are supported at any given time. -To set your own version for hermetic Python toolchain, use `TF_PYTHON_VERSION` -environment variable, e.g. +By default, the lowest supported version is used. + +To set a different version, use the `TF_PYTHON_VERSION` environment variable, +e.g. ``` -export TF_PYTHON_VERSION=3.10 +export TF_PYTHON_VERSION=3.11 ``` -To set a version from argument line, add to your command +To specify the version via a Bazel command argument, use the following: ``` ---repo_env=TF_PYTHON_VERSION=3.10 +--repo_env=TF_PYTHON_VERSION=3.11 ``` -### Requirements updater - -Requirements updater is a standalone tool intended to simplify process of -updating requirements for multiple versions of Python. +## Requirements updater -#### How to update/add requirements +Requirements updater is a standalone tool, intended to simplify process of +updating requirements for multiple minor versions of Python. -By default, the name of the input requirements file is `requirements.in`, -but it can be set using the `REQUIREMENTS_FILE_NAME` variable, for example: -``` -export REQUIREMENTS_FILE_NAME=`my_requirements.in` -``` +It takes in a file with a set of dependencies, and produces a more detailed +requirements file for each version, with hashes specified for each +dependency required, as well as their sub-dependencies. -To set a version from the argument line, add to your command -``` ---repo_env=REQUIREMENTS_FILE_NAME=`my_requirements.in` -``` +### How to update/add requirements -#### How to run the updater +By default, the name of the base requirements file is `requirements.in`, but it +can be set using the `REQUIREMENTS_FILE_NAME` variable. \ +For example: ``` -bash updater.sh +export REQUIREMENTS_FILE_NAME=my_requirements.in ``` -### How to add a new Python version - -1) In the `WORKSPACE` file add a new version to `python_versions` argument of -the `python_register_multi_toolchains` function. - -2) In `BUILD.bazel` file add a load statement for the new version, e.g. +To specify the file via a Bazel command argument, use the following: ``` -load("@python//3.11:defs.bzl", - compile_pip_requirements_3_11 = "compile_pip_requirements") +--repo_env=REQUIREMENTS_FILE_NAME=my_requirements.in ``` -Add a new entry for the loaded `compile_pip_requirements`, e.g. 
+
+   ```
+   compile_pip_requirements_3_11(
+       name = "requirements_3_11",
+       extra_args = ["--allow-unsafe"],
+       requirements_in = "requirements.in",
+       requirements_txt = "requirements_lock_3_11.txt",
+   )
+   ```
+
+   ```
+   compile_pip_requirements_3_11(
+       name = "requirements_3_11_release",
+       extra_args = [
+           "--allow-unsafe",
+           "-P keras-nightly",
+           "-P tb-nightly",
+           "-P tf-estimator-nightly",
+       ],
+       requirements_in = "requirements.in",
+       requirements_txt = "requirements_lock_3_11.txt",
+   )
+   ```
+
+4) Add the version to `SUPPORTED_VERSIONS` in `updater.sh`, and
+   `release_updater.sh`.
+
+5) Run the `updater.sh` shell script. \
+   If the base requirements file hasn't yet been updated for the new Python
+   version (at least some dependencies will require different versions),
+   update it now so that the script can run successfully.
+
+6) A new `requirements_lock_3_11.txt` file should appear under the root of the
+   `tensorflow` directory.

From ca428728015daa000ce890e3584e51d502140076 Mon Sep 17 00:00:00 2001
From: Anlun Xu
Date: Mon, 23 Oct 2023 10:20:17 -0700
Subject: [PATCH 014/246] [xla:gpu] Allow users to toggle command types that
 are enabled in command buffers

Before this CL, the libraries supported by GPU graphs were determined by an
integer level. This CL allows the user to have fine-grained control over
whether a library should be enabled in graphs.
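In the code, the new flag is spelled --xla_gpu_enable_command_buffer; the
legacy --xla_gpu_graph_level flag is kept as a sub-parser that expands
cumulatively into the new repeated enum (level 1 enables FUSION, level 2 adds
CUBLAS, level 3 adds CUDNN), so existing command lines keep working. For
instance, --xla_gpu_graph_level=2 should select the same command types as
--xla_gpu_enable_command_buffer=FUSION,CUBLAS.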
Example usage: XLA_FLAGS=--xla_gpu_command_buffer_command_types=FUSION,CUBLAS PiperOrigin-RevId: 575856977 --- third_party/xla/xla/BUILD | 1 + third_party/xla/xla/debug_options_flags.cc | 62 ++++++++++++++++--- .../xla/mlir/backends/gpu/transforms/BUILD | 1 + .../gpu/transforms/outline_cuda_graphs.cc | 21 ++++--- .../mlir/backends/gpu/transforms/passes.cc | 2 +- .../xla/mlir/backends/gpu/transforms/passes.h | 7 ++- third_party/xla/xla/service/gpu/BUILD | 2 +- .../xla/service/gpu/autotuner_compile_util.cc | 2 +- .../service/gpu/compile_module_to_llvm_ir.cc | 13 +++- third_party/xla/xla/xla.proto | 24 ++++--- 10 files changed, 104 insertions(+), 31 deletions(-) diff --git a/third_party/xla/xla/BUILD b/third_party/xla/xla/BUILD index 79dc02a7966200..384e20b3422e5d 100644 --- a/third_party/xla/xla/BUILD +++ b/third_party/xla/xla/BUILD @@ -1076,6 +1076,7 @@ cc_library( "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:protobuf", "@local_tsl//tsl/util:command_line_flags", ], ) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 44001445359b3b..1d7dfa37efdcb2 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -25,11 +25,16 @@ limitations under the License. #include "absl/base/call_once.h" #include "absl/container/flat_hash_map.h" #include "absl/container/node_hash_map.h" +#include "absl/strings/ascii.h" +#include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" #include "xla/debug_options_parsers.h" #include "xla/parse_flags_from_env.h" #include "xla/xla.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep #include "tsl/util/command_line_flags.h" namespace xla { @@ -93,9 +98,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // flag. opts.set_xla_gpu_enable_cublaslt(false); - // TODO(b/258036887): Create separate flags for enabling cuBLAS, cuDNN, and - // NCCL in GPU graphs. - opts.set_xla_gpu_graph_level(1); + opts.add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); opts.set_xla_gpu_graph_num_runs_to_instantiate(-1); opts.set_xla_gpu_enable_persistent_temp_buffers(false); opts.set_xla_gpu_graph_min_graph_size(5); @@ -355,6 +358,46 @@ void MakeDebugOptionsFlags(std::vector* flag_list, return true; }; + // Custom "sub-parser" lambda for xla_gpu_graph_level. + auto setter_for_xla_gpu_graph_level = [debug_options](const int32_t level) { + debug_options->clear_xla_gpu_enable_command_buffer(); + if (level >= 1) { + debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::FUSION); + } + if (level >= 2) { + debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::CUBLAS); + } + if (level >= 3) { + debug_options->add_xla_gpu_enable_command_buffer(DebugOptions::CUDNN); + } + return true; + }; + + auto command_types_to_string = + [](tsl::protobuf::RepeatedField command_types) -> std::string { + struct Formatter { + void operator()(std::string* out, int type) const { + absl::StrAppend(out, DebugOptions::CommandBufferCmdType_Name(type)); + } + }; + return absl::StrJoin(command_types, ", ", Formatter()); + }; + + // Custom "sub-parser" lambda for xla_gpu_enable_command_buffer. 
+ auto setter_for_xla_gpu_enable_command_buffer = + [debug_options](const std::string& values) { + debug_options->clear_xla_gpu_enable_command_buffer(); + for (const absl::string_view value : absl::StrSplit(values, ',')) { + DebugOptions::CommandBufferCmdType cmd_type; + if (!DebugOptions::CommandBufferCmdType_Parse( + absl::AsciiStrToUpper(value), &cmd_type)) { + return false; + } + debug_options->add_xla_gpu_enable_command_buffer(cmd_type); + } + return true; + }; + // Custom "sub-parser" for xla_fuel. Note that ConsumeFuel does not do any // locking on the fuel global variables. This means that it's // illegal/undefined behavior to modify this flag value while the compiler is @@ -943,11 +986,14 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_gpu_enable_cublaslt(), "Use cuBLASLt for GEMMs when possible.")); flag_list->push_back(tsl::Flag( - "xla_gpu_graph_level", - int32_setter_for(&DebugOptions::set_xla_gpu_graph_level), - debug_options->xla_gpu_graph_level(), - "Set GPU graph level. 0 = off; 1 = capture fusions and memcpys; 2 = " - "capture gemms; 3 = capture convolutions.")); + "xla_gpu_graph_level", setter_for_xla_gpu_graph_level, 1, + "The legacy flag for setting GPU graph level. Use " + "xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture " + "fusions and memcpys; 2 = capture gemms; 3 = capture convolutions.")); + flag_list->push_back(tsl::Flag( + "xla_gpu_enable_command_buffer", setter_for_xla_gpu_enable_command_buffer, + command_types_to_string(debug_options->xla_gpu_enable_command_buffer()), + "The types of the commands that are recorded into command buffers")); flag_list->push_back(tsl::Flag( "xla_gpu_graph_num_runs_to_instantiate", int32_setter_for( diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD b/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD index 758b54e4ef28f4..fd1c250a2d416e 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/BUILD @@ -77,6 +77,7 @@ cc_library( "//xla/stream_executor:blas", "//xla/stream_executor:device_description", "//xla/translate/mhlo_to_hlo:location_exporter", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc index a32511c5b1796f..fe004d703d8c22 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/outline_cuda_graphs.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -57,8 +58,10 @@ class OutlineGpuGraphsPass : public impl::OutlineGpuGraphsPassBase { public: OutlineGpuGraphsPass() = default; - explicit OutlineGpuGraphsPass(int gpu_graph_level, int min_graph_size) - : gpu_graph_level_(gpu_graph_level) { + explicit OutlineGpuGraphsPass( + absl::flat_hash_set command_types, + int min_graph_size) + : command_types_(std::move(command_types)) { this->min_graph_size_ = min_graph_size; } @@ -69,6 +72,8 @@ class OutlineGpuGraphsPass } private: + absl::flat_hash_set command_types_ = { + DebugOptions::FUSION, DebugOptions::CUBLAS, DebugOptions::CUDNN}; int gpu_graph_level_ = 3; }; @@ -458,7 +463,7 @@ void OutlineGpuGraphsPass::runOnOperation() { OpCapturePatternSet patterns; - if (gpu_graph_level_ >= 1) { + if (command_types_.contains(DebugOptions::FUSION)) { // Enable capturing fusions and memcpies. patterns.emplace_back(new LaunchFuncOpCapture()); patterns.emplace_back(new ConstantOpCapture()); @@ -467,12 +472,12 @@ void OutlineGpuGraphsPass::runOnOperation() { patterns.emplace_back(new ReinterpretCastOpCapture()); } - if (gpu_graph_level_ >= 2) { + if (command_types_.contains(DebugOptions::CUBLAS)) { // Enable capturing gemms. patterns.emplace_back(new GemmOpCapture()); } - if (gpu_graph_level_ >= 3) { + if (command_types_.contains(DebugOptions::CUDNN)) { // Enable capturing convolutions. patterns.emplace_back(new ConvForwardOpCapture()); patterns.emplace_back(new ConvBackwardInputOpCapture()); @@ -494,9 +499,9 @@ std::unique_ptr> createOutlineGpuGraphsPass() { } std::unique_ptr> createOutlineGpuGraphsPass( - int gpu_graph_level, int min_graph_size) { - return std::make_unique(gpu_graph_level, - min_graph_size); + absl::flat_hash_set command_types, + int min_graph_size) { + return std::make_unique(command_types, min_graph_size); } } // namespace gpu diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc index 9f6ea047c7dc0e..a131fcefad745a 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.cc @@ -39,7 +39,7 @@ void populateXlaGpuRuntimePasses(mlir::OpPassManager& pm, // Outline CUDA-Graph-compatible operations into graph capture functions. pm.addPass( - createOutlineGpuGraphsPass(opts.gpu_graph_level, opts.min_graph_size)); + createOutlineGpuGraphsPass(opts.command_types, opts.min_graph_size)); if (opts.enable_concurrent_region) { // Concurrent regions create repeated-fork-join topology inside CUDA graphs, // which is not optimized by architectures prior to Ampere and may cause diff --git a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h index 4c3fc987ea74b5..9abd1e9319d866 100644 --- a/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h +++ b/third_party/xla/xla/mlir/backends/gpu/transforms/passes.h @@ -19,10 +19,12 @@ limitations under the License. 
#include #include +#include "absl/container/flat_hash_set.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "xla/stream_executor/device_description.h" +#include "xla/xla.pb.h" namespace xla { namespace gpu { @@ -44,7 +46,7 @@ struct GpuPipelineOpts { // Enable experimental pass that outlines parts of the XLA computation into // CUDA Graphs, which allows us to amortize the cost of launching multiple // device kernels. - int32_t gpu_graph_level = 0; + absl::flat_hash_set command_types; int32_t min_graph_size = 0; bool enable_concurrent_region = false; stream_executor::GpuComputeCapability compute_capability; @@ -106,7 +108,8 @@ std::unique_ptr> createOutlineGpuGraphsPass(); std::unique_ptr> createOutlineGpuGraphsPass( - int32_t gpu_graph_level, int32_t min_graph_size); + absl::flat_hash_set command_types, + int32_t min_graph_size); //===----------------------------------------------------------------------===// // Passes for marking concurrent region in CUDA graph capture function. diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index d3be0ea927744f..bc03ccb1365c18 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -2464,7 +2464,6 @@ cc_library( ":buffer_sharing", ":executable_proto_cc", ":gpu_constants", - ":gpu_convert_async_collectives_to_sync", ":gpu_executable", ":ir_emitter_context", ":ir_emitter_unnested", @@ -2493,6 +2492,7 @@ cc_library( "//xla/translate/mhlo_to_hlo:location_exporter", "//xla/translate/mhlo_to_lhlo_with_xla", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/strings", "@llvm-project//llvm:AsmParser", "@llvm-project//llvm:Support", diff --git a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc index ed91d4952c9a11..69659373cf8115 100644 --- a/third_party/xla/xla/service/gpu/autotuner_compile_util.cc +++ b/third_party/xla/xla/service/gpu/autotuner_compile_util.cc @@ -88,7 +88,7 @@ AutotunerCompileUtil::AutotunerCompileUtil(const AutotuneConfig& config, // Avoid using another thread pool. opts_.set_xla_gpu_force_compilation_parallelism(1); // Avoid using GPU graphs as we don't want to measure graph construction time. - opts_.set_xla_gpu_graph_level(0); + opts_.clear_xla_gpu_enable_command_buffer(); // Disable experimental XLA:GPU runtime. opts_.set_xla_gpu_enable_gpu2_runtime(false); opts_.set_xla_embed_ir_in_executable(false); diff --git a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc index 7a08cfa5410407..c4387ff9ff3974 100644 --- a/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc +++ b/third_party/xla/xla/service/gpu/compile_module_to_llvm_ir.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/strings/str_cat.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/AsmParser/Parser.h" @@ -122,8 +123,18 @@ static Status LowerToXlaGpuRuntime( mlir::PassManager pm(module->getName(), mlir::PassManager::Nesting::Implicit); pm.enableVerifier(should_verify); + absl::flat_hash_set command_types; + for (int command_type_num : debug_options.xla_gpu_enable_command_buffer()) { + if (!DebugOptions::CommandBufferCmdType_IsValid(command_type_num)) { + return InternalError("Invalid command buffer command type"); + } + DebugOptions::CommandBufferCmdType command_type = + static_cast(command_type_num); + command_types.insert(command_type); + } + GpuPipelineOpts opts; - opts.gpu_graph_level = debug_options.xla_gpu_graph_level(); + opts.command_types = command_types; opts.min_graph_size = debug_options.xla_gpu_graph_min_graph_size(); opts.enable_concurrent_region = debug_options.xla_gpu_graph_enable_concurrent_region(); diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 739960065d5125..f95650ab1b2470 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -458,13 +458,18 @@ message DebugOptions { // Whether to use cuBLASLt for GEMMs on GPUs. bool xla_gpu_enable_cublaslt = 166; - // 0: Disable GPU graph capture. - // 1: Enable GPU graphs for fusions and memcpy (safest ones). - // 2: Enable GPU graphs for gemms. - // 3: Enable GPU graphs for convolutions. - // - // Default: 0. - int32 xla_gpu_graph_level = 194; + // Commands are categorized into four types: FUSION represents regular fusion + // kernels. CUBLAS, CUDNN, and NCCL represent library calls. + enum CommandBufferCmdType { + INVALID = 0; + FUSION = 1; + CUBLAS = 2; + CUDNN = 3; + NCCL = 4; + } + + // Determine the types of commands that are recorded into command buffers. + repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258; // Only instantiates a GPU graph after the captured function execution count // reaches the threshold. This constant is a heuristic to avoid creating a @@ -646,7 +651,7 @@ message DebugOptions { int32 xla_gpu_llvm_verification_level = 256; - // Next id: 258 + // Next id: 259 // Extra options to pass to the compilation backend (e.g. LLVM); specific // interpretation of these values is left to the backend. @@ -659,7 +664,8 @@ message DebugOptions { // xla_gpu_enable_cuda_graphs // xla_gpu_allow_all_reduce_kernel // xla_gpu_enable_experimental_block_size - reserved 5, 117, 133, 139, 176, 178, 180, 193, 214; + // xla_gpu_graph_level + reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194; } // Contains flags which affects the GPU compilation result. From 1ce130020c89d6a8c8910dad06579ec28521ab92 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 10:43:58 -0700 Subject: [PATCH 015/246] Refactor Requantize in ConvertMHLOQuantToInt pass Consolidate implementations of requantize and use the same one as TF quantizer. 
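The consolidated lowering folds dequantize-then-quantize into a single affine
rescale:

  output       = input * merged_scale + merged_zp
  merged_scale = input_scale / output_scale
  merged_zp    = output_zp - input_zp * merged_scale

Worked example with illustrative values: requantizing from (scale=0.5, zp=10)
to (scale=0.25, zp=-4) gives merged_scale = 2.0 and
merged_zp = -4 - 10 * 2.0 = -24, so a stored value of 13 (real value
(13 - 10) * 0.5 = 1.5) maps to 13 * 2.0 - 24 = 2, and indeed
(2 - (-4)) * 0.25 = 1.5. When the input and output quantized types match, the
rescale reduces to a plain convert; outputs narrower than 32 bits are clamped
to the storage range.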
PiperOrigin-RevId: 575864178 --- .../bridge/convert_mhlo_quant_to_int.cc | 304 +++++++----------- .../convert_tf_quant_to_mhlo_int_test.cc | 86 ++++- .../bridge/convert-mhlo-quant-to-int.mlir | 67 ++-- 3 files changed, 220 insertions(+), 237 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc index 505554bdad4bb0..ee80a6cd18b7b1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_mhlo_quant_to_int.cc @@ -15,11 +15,8 @@ limitations under the License. #include #include -#include #include #include -#include -#include #include #include "absl/algorithm/container.h" @@ -31,8 +28,6 @@ limitations under the License. #include "mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project -#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project -#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project @@ -49,7 +44,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" -#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/math_utils.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/mhlo/transforms/rewriters.h" @@ -60,70 +54,76 @@ namespace { #define GEN_PASS_DEF_CONVERTMHLOQUANTTOINT #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h.inc" -// This helper function create ops to requantize `input` tensor and output to -// `res_int32` tensor. Clamping is omitted because for some ops clamping can be -// done later to avoid duplicate. -LogicalResult RequantizeWithoutClamping( - mlir::OpState op, Value input, TensorType int32_tensor_type, - quant::UniformQuantizedType input_quantized_type, - quant::UniformQuantizedType result_quantized_type, Value &res_int32, - ConversionPatternRewriter &rewriter) { +// This helper function create ops to requantize `input` tensor and returns the +// output tensor. Clamping is done if output integer bit-width < 32. +// +// Requantization is essentially dequantize --> quantize. +// +// Dequantize: (input - zp) * scale +// Quantize: input / scale + zp +// +// Hence, +// output = (input - input_zp) * input_scale / output_scale + output_zp +// +// This is simplified as: +// output = input * merged_scale + merged_zp +// where: +// merged_zp = output_zp - input_zp * merged_scale. +// merged_scale = input_scale / output_scale. +Value Requantize(mlir::OpState op, Value input, + UniformQuantizedType input_quantized_type, + UniformQuantizedType output_quantized_type, + TensorType output_tensor_type, + ConversionPatternRewriter &rewriter) { // Skip requantization when input and result have the same type. 
- if (input_quantized_type == result_quantized_type) { - res_int32 = rewriter.create(op->getLoc(), - int32_tensor_type, input); - return success(); + if (input_quantized_type == output_quantized_type) { + return rewriter.create(op->getLoc(), output_tensor_type, + input); } - // Convert input to int32 tensor. - res_int32 = - rewriter.create(op->getLoc(), int32_tensor_type, input); - // Undo the input zero point. - Value input_zero_point = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - input_quantized_type.getZeroPoint()))); - res_int32 = rewriter.create( - op->getLoc(), int32_tensor_type, res_int32, input_zero_point, nullptr); - - // Adjust the scale. - const double effective_scale = - input_quantized_type.getScale() / result_quantized_type.getScale(); - int32_t effective_quantized_fraction; - int32_t effective_shift; - if (failed(quant::stablehlo::QuantizeMultiplier( - effective_scale, effective_quantized_fraction, effective_shift))) { - op->emitError("Invalid effective quantization scale."); - return failure(); - } - Value multiplier = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr( - static_cast(effective_quantized_fraction))); - // The effective_quantized_fraction value has been quantized by multiplying - // (1 << 15). So, we have to shift it back by (15 - effective_shift) to get - // the desired outcome. - Value total_shift = rewriter.create( + double merged_scale_fp = + input_quantized_type.getScale() / output_quantized_type.getScale(); + Value merged_scale = rewriter.create( op->getLoc(), - rewriter.getI32IntegerAttr(static_cast(15 - effective_shift))); - - // Apply the effective scale with rounding. - Value half = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr( - static_cast(1 << (14 - effective_shift)))); - res_int32 = rewriter.create( - op->getLoc(), int32_tensor_type, res_int32, multiplier, nullptr); - res_int32 = rewriter.create( - op->getLoc(), int32_tensor_type, res_int32, half, nullptr); - res_int32 = rewriter.create( - op->getLoc(), int32_tensor_type, res_int32, total_shift, nullptr); - - // Apply the output zero point. - Value output_zero_point = rewriter.create( - op->getLoc(), rewriter.getI32IntegerAttr(static_cast( - result_quantized_type.getZeroPoint()))); - res_int32 = rewriter.create( - op->getLoc(), int32_tensor_type, res_int32, output_zero_point, nullptr); + rewriter.getF32FloatAttr(static_cast(merged_scale_fp))); - return success(); + auto float_tensor_type = + input.getType().cast().clone(rewriter.getF32Type()); + Value output_float = + rewriter.create(op->getLoc(), float_tensor_type, input); + + output_float = rewriter.create( + op->getLoc(), float_tensor_type, output_float, merged_scale, nullptr); + + // Add merged_zp only when it is non-zero. + double merged_zp_fp = output_quantized_type.getZeroPoint() - + input_quantized_type.getZeroPoint() * merged_scale_fp; + if (merged_zp_fp != 0) { + Value merged_zp = rewriter.create( + op->getLoc(), + rewriter.getF32FloatAttr(static_cast(merged_zp_fp))); + output_float = rewriter.create( + op->getLoc(), float_tensor_type, output_float, merged_zp, nullptr); + } + + // Clamp output if the output integer bit-width <32. 
+ if (output_tensor_type.getElementType().cast().getWidth() < 32) { + Value quantization_min = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(static_cast( + output_quantized_type.getStorageTypeMin()))); + Value quantization_max = rewriter.create( + op->getLoc(), rewriter.getF32FloatAttr(static_cast( + output_quantized_type.getStorageTypeMax()))); + // Clamp results by [quantization_min, quantization_max]. + output_float = rewriter.create( + op->getLoc(), float_tensor_type, quantization_min, output_float, + quantization_max); + } + + output_float = rewriter.create( + op->getLoc(), float_tensor_type, output_float); + return rewriter.create(op->getLoc(), output_tensor_type, + output_float); } class ConvertMHLOQuantToInt @@ -149,7 +149,7 @@ class ConvertUniformQuantizeOp mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto quantized_type = getElementTypeOrSelf(op.getResult().getType()) - .dyn_cast(); + .dyn_cast(); // Currently for activation, PTQ supports per-tensor quantization only, and // UniformQuantize op is only for activation. if (!quantized_type) { @@ -159,7 +159,7 @@ class ConvertUniformQuantizeOp auto input_element_type = getElementTypeOrSelf(op.getOperand().getType()); if (input_element_type.isF32()) { return matchAndRewriteQuantize(op, adaptor, rewriter, quantized_type); - } else if (input_element_type.isa()) { + } else if (input_element_type.isa()) { return matchAndRewriteRequantize(op, adaptor, rewriter, quantized_type); } return rewriter.notifyMatchFailure(op, "Unsupported input element type."); @@ -168,7 +168,7 @@ class ConvertUniformQuantizeOp LogicalResult matchAndRewriteQuantize( mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, ConversionPatternRewriter &rewriter, - const quant::UniformQuantizedType &quantized_type) const { + const UniformQuantizedType &quantized_type) const { Value scale = rewriter.create( op->getLoc(), rewriter.getF32FloatAttr(quantized_type.getScale())); Value zero_point = rewriter.create( @@ -201,71 +201,18 @@ class ConvertUniformQuantizeOp return success(); } - // Requantization is essentially dequantize --> quantize. - // - // Dequantize: (input - zp) * scale - // Quantize: input / scale + zp - // - // Hence, - // result = (input - input_zp) * input_scale / output_scale + output_zp - // - // This is simplified as: - // result = input * merged_scale + merged_zp - // where: - // merged_zp = output_zp - input_zp * merged_scale. - // merged_scale = input_scale / output_scale. 
LogicalResult matchAndRewriteRequantize( mhlo::UniformQuantizeOp op, mhlo::UniformQuantizeOpAdaptor adaptor, ConversionPatternRewriter &rewriter, - const quant::UniformQuantizedType &output_quantized_type) const { + const UniformQuantizedType &output_quantized_type) const { auto input_quantized_type = getElementTypeOrSelf(op.getOperand().getType()) - .cast(); - auto result_quantized_type = getElementTypeOrSelf(op.getResult().getType()) - .cast(); - - double merged_scale_fp = - input_quantized_type.getScale() / result_quantized_type.getScale(); - Value merged_scale = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_scale_fp))); - - auto res_float_tensor_type = - op.getOperand().getType().clone(rewriter.getF32Type()); - Value res_float = rewriter.create( - op->getLoc(), res_float_tensor_type, adaptor.getOperand()); - - res_float = rewriter.create( - op->getLoc(), res_float_tensor_type, res_float, merged_scale, nullptr); - - // Add merged_zp only when it is non-zero. - double merged_zp_fp = result_quantized_type.getZeroPoint() - - input_quantized_type.getZeroPoint() * merged_scale_fp; - if (merged_zp_fp != 0) { - Value merged_zp = rewriter.create( - op->getLoc(), - rewriter.getF32FloatAttr(static_cast(merged_zp_fp))); - res_float = rewriter.create( - op->getLoc(), res_float_tensor_type, res_float, merged_zp, nullptr); - } - - Value quantization_min = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMin()))); - Value quantization_max = rewriter.create( - op->getLoc(), rewriter.getF32FloatAttr(static_cast( - output_quantized_type.getStorageTypeMax()))); - - // Clamp results by [quantization_min, quantization_max]. - res_float = rewriter.create( - op->getLoc(), res_float_tensor_type, quantization_min, res_float, - quantization_max); - res_float = rewriter.create( - op->getLoc(), res_float_tensor_type, res_float); - - auto res_final_tensor_type = - res_float_tensor_type.clone(output_quantized_type.getStorageType()); - rewriter.replaceOpWithNewOp(op, res_final_tensor_type, - res_float); + .cast(); + rewriter.replaceOp( + op, Requantize(op, adaptor.getOperand(), input_quantized_type, + output_quantized_type, + op.getResult().getType().cast().clone( + output_quantized_type.getStorageType()), + rewriter)); return success(); } }; @@ -279,7 +226,7 @@ class ConvertUniformDequantizeOp mhlo::UniformDequantizeOp op, mhlo::UniformDequantizeOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto element_type = getElementTypeOrSelf(op.getOperand().getType()) - .dyn_cast(); + .dyn_cast(); // Currently for activation, PTQ supports per-tensor quantization only, and // UniformQuantize op is only for activation. if (!element_type) { @@ -317,18 +264,14 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { LogicalResult matchAndRewrite( mhlo::AddOp op, mhlo::AddOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto lhs_element_type = op.getLhs() - .getType() - .getElementType() - .dyn_cast(); - auto rhs_element_type = op.getRhs() - .getType() - .getElementType() - .dyn_cast(); + auto lhs_element_type = + op.getLhs().getType().getElementType().dyn_cast(); + auto rhs_element_type = + op.getRhs().getType().getElementType().dyn_cast(); auto result_element_type = op.getResult() .getType() .getElementType() - .dyn_cast(); + .dyn_cast(); // We only handle cases where lhs, rhs and results all have quantized // element type. 
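One subtlety in the AddOp hunk that follows is worth spelling out (editor's illustration, not part of the patch): after both operands are requantized to the result's quantized type, each integer value carries the result zero point once, so their sum carries it twice; the constant built right after the two Requantize calls exists to subtract one copy back out. The chlo op consuming that constant falls outside the excerpted hunk, but the FileCheck tests later in this patch show the broadcast_add of the two requantized operands followed by the result-zero-point constant. A scalar sketch, assuming exact arithmetic:

#include <cassert>
#include <cstdint>

// Map a real value into a (scale, zero_point) integer representation.
int32_t Quantize(double real, double scale, int32_t zero_point) {
  return static_cast<int32_t>(real / scale) + zero_point;
}

int main() {
  const double scale = 0.5;  // shared result scale, chosen for illustration
  const int32_t zp = 10;     // result zero point
  const int32_t qa = Quantize(4.0, scale, zp);  // 8 + 10 = 18
  const int32_t qb = Quantize(6.0, scale, zp);  // 12 + 10 = 22
  // qa + qb == 40 carries zp twice; subtracting one copy restores it.
  assert(qa + qb - zp == Quantize(4.0 + 6.0, scale, zp));  // 30 == 30
  return 0;
}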
@@ -347,20 +290,14 @@ class ConvertUniformQuantizedAddOp : public OpConversionPattern { // be the same as the result. // TODO: b/260280919 - Consider avoiding conversion to int32. Value lhs = adaptor.getLhs(); - Value lhs_int32_tensor; - if (failed(RequantizeWithoutClamping(op, lhs, res_int32_tensor_type, - lhs_element_type, result_element_type, - lhs_int32_tensor, rewriter))) { - return failure(); - } + Value lhs_int32_tensor = + Requantize(op, lhs, lhs_element_type, result_element_type, + res_int32_tensor_type, rewriter); Value rhs = adaptor.getRhs(); - Value rhs_int32_tensor; - if (failed(RequantizeWithoutClamping(op, rhs, res_int32_tensor_type, - rhs_element_type, result_element_type, - rhs_int32_tensor, rewriter))) { - return failure(); - } + Value rhs_int32_tensor = + Requantize(op, rhs, rhs_element_type, result_element_type, + res_int32_tensor_type, rewriter); Value zero_point = rewriter.create( op->getLoc(), rewriter.getI32IntegerAttr(static_cast( @@ -437,9 +374,9 @@ LogicalResult matchAndRewriteDotLikeHybridOp( // result = hybridOp(lhs, dequant(rhs)) Value lhs_float32_tensor = adaptor.getLhs(); Value rhs = adaptor.getRhs(); - quant::UniformQuantizedType rhs_element_type = + UniformQuantizedType rhs_element_type = getElementTypeOrSelf(op.getRhs().getType()) - .template cast(); + .template cast(); auto res_float32_tensor_type = op.getResult().getType().template cast(); auto rhs_float32_tensor_type = @@ -481,7 +418,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, // Calculate the output tensor shape. This is input tensor dims minus // contracting dims. - auto ranked_tensor = tensor.getType().dyn_cast(); + auto ranked_tensor = tensor.getType().cast(); llvm::SmallVector output_dims; for (int64_t i = 0; i < ranked_tensor.getRank(); ++i) { if (absl::c_count(reduction_dims, i) == 0) { @@ -492,7 +429,7 @@ Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, // Convert input tensor to output type since mhlo::Reduce only supports same // element type for input/output. tensor = builder.create( - loc, tensor.getType().dyn_cast().clone(output_element_type), + loc, tensor.getType().cast().clone(output_element_type), tensor); auto reducer_tensor_type = RankedTensorType::get({}, output_element_type); @@ -592,7 +529,7 @@ Value BroadcastZpContribution(OpBuilder &builder, Location loc, // zero-point-offset tensor to the final output tensor, and then do the // broadcast. auto zp_contribution_rank = - zp_contribution.getType().dyn_cast().getRank(); + zp_contribution.getType().cast().getRank(); llvm::SmallVector broadcast_dims; broadcast_dims.resize(zp_contribution_rank, 0); // Result tensor will have batching dims first, then LHS result dims, then @@ -615,7 +552,7 @@ Value BroadcastZpContribution(OpBuilder &builder, Location loc, } // Use broadcast_in_dim or dyanmic_broadcast_in_dim based on input shape // dynamism. - if (zp_contribution.getType().dyn_cast().hasStaticShape()) { + if (zp_contribution.getType().cast().hasStaticShape()) { zp_contribution = builder.create( loc, output_tensor_type, zp_contribution, DenseIntElementsAttr::get( @@ -742,13 +679,13 @@ Value CreateDotLikeKernel(OpBuilder &builder, Location loc, DenseIntElementsAttr::get( RankedTensorType::get({}, builder.getI8Type()), {static_cast(getElementTypeOrSelf(op.getLhs().getType()) - .dyn_cast() + .cast() .getZeroPoint())})); // Convert Padding attributes from mhlo::Convolution to mhlo::Pad. Note that // Padding is applied for spatial dimensions [1...rank-1) only for // mhlo::Convolution. 
But mhlo::Pad require those for all dimensions. Hence // we add 0 to the beginning and end of the padding vectors. - int64_t rank = lhs.getType().dyn_cast().getRank(); + int64_t rank = lhs.getType().cast().getRank(); llvm::SmallVector padding_low(rank, 0), padding_high(rank, 0), padding_interior(rank, 0); for (int64_t i = 1; i < rank - 1; ++i) { @@ -786,15 +723,12 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, ConversionPatternRewriter &rewriter) { // Lower Dot/DotGeneral UQ ops to DotGeneral int. // Assumes that operands and results are uq types. - auto lhs_element_quant_type = - getElementTypeOrSelf(op.getLhs().getType()) - .template dyn_cast(); - auto rhs_element_quant_type = - getElementTypeOrSelf(op.getRhs().getType()) - .template dyn_cast(); - auto res_element_quant_type = - getElementTypeOrSelf(op.getResult()) - .template dyn_cast(); + auto lhs_element_quant_type = getElementTypeOrSelf(op.getLhs().getType()) + .template dyn_cast(); + auto rhs_element_quant_type = getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto res_element_quant_type = getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); Value lhs = adaptor.getLhs(); Value rhs = adaptor.getRhs(); auto res_int32_tensor_type = @@ -837,8 +771,7 @@ LogicalResult matchAndRewriteDotLikeOp(DotLikeOp op, DotLikeOpAdaptor adaptor, // Skip zp_offset if it is 0. if (zp_offset) { auto zp_offset_float32_tensor_type = - zp_offset.getType().dyn_cast().clone( - rewriter.getF32Type()); + zp_offset.getType().cast().clone(rewriter.getF32Type()); zp_offset = rewriter.create( op->getLoc(), zp_offset_float32_tensor_type, zp_offset); zp_offset = rewriter.create( @@ -867,15 +800,12 @@ template FailureOr IsDotLikeOpHybrid(DotLikeOp op) { // Checks whether a dot-like op is hybrid by looking at input/output types. // Returns failure() when the type is not supported. - auto lhs_element_quant_type = - getElementTypeOrSelf(op.getLhs().getType()) - .template dyn_cast(); - auto rhs_element_quant_type = - getElementTypeOrSelf(op.getRhs().getType()) - .template dyn_cast(); - auto res_element_quant_type = - getElementTypeOrSelf(op.getResult()) - .template dyn_cast(); + auto lhs_element_quant_type = getElementTypeOrSelf(op.getLhs().getType()) + .template dyn_cast(); + auto rhs_element_quant_type = getElementTypeOrSelf(op.getRhs().getType()) + .template dyn_cast(); + auto res_element_quant_type = getElementTypeOrSelf(op.getResult()) + .template dyn_cast(); if (lhs_element_quant_type && rhs_element_quant_type && res_element_quant_type) { return false; @@ -996,8 +926,7 @@ bool IsConvNDHWC(const mhlo::ConvDimensionNumbersAttr &dims) { FailureOr VerifyConvolutionOp(mhlo::ConvolutionOp op) { // RHS (weight) must have zero zp. auto rhs_element_quant_type = - getElementTypeOrSelf(op.getRhs().getType()) - .template dyn_cast(); + getElementTypeOrSelf(op.getRhs().getType()).cast(); if (rhs_element_quant_type.getZeroPoint() != 0) { op->emitError("RHS UQ type must have zero zp."); return failure(); @@ -1074,15 +1003,15 @@ class ConvertGenericOp : public ConversionPattern { // Check that all operands and result uq types are the same. 
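// Editor's illustration (not part of the patch): this same-type check is
// what lets shape-only ops flow through the generic pattern. For example, a
// hypothetical mhlo.reshape whose operand and result both carry the same
// uniform-quantized element type contributes matching entries to uq_types
// below, so the op can simply be recreated on the storage type, while an op
// mixing two different uniform-quantized element types would not match.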
llvm::SmallVector uq_types; for (auto result_type : op->getResultTypes()) { - auto type = getElementTypeOrSelf(result_type) - .dyn_cast(); + auto type = + getElementTypeOrSelf(result_type).dyn_cast(); if (type) { uq_types.push_back(type); } } for (auto operand : op->getOperands()) { auto type = getElementTypeOrSelf(operand.getType()) - .dyn_cast(); + .dyn_cast(); if (type) { uq_types.push_back(type); } @@ -1097,11 +1026,10 @@ class ConvertGenericOp : public ConversionPattern { // type otherwise. llvm::SmallVector new_result_types; for (auto result_type : op->getResultTypes()) { - if (getElementTypeOrSelf(result_type) - .isa()) { + if (getElementTypeOrSelf(result_type).isa()) { new_result_types.push_back(result_type.cast().clone( getElementTypeOrSelf(result_type) - .cast() + .cast() .getStorageType())); } else { new_result_types.push_back(result_type); @@ -1122,7 +1050,7 @@ class UQTypeConverter : public TypeConverter { UQTypeConverter() { addConversion([](Type type) -> Type { auto to_legal_type = [](Type type) { - if (auto uq_type = dyn_cast(type)) { + if (auto uq_type = dyn_cast(type)) { return uq_type.getStorageType(); } return type; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index 9c2b88abddb8a5..468e5518abebfc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -158,9 +158,7 @@ class ConvertTfQuantToMhloIntTest : public ::testing::Test { // MHLO int passes. PassManager pm(module_op->getContext()); pm.addNestedPass(CreateConvertTFQuantTypesPass()); - pm.addNestedPass(CreateConvertTFQuantOpsToMHLOPass()); - pm.addNestedPass( - stablehlo::createConvertMHLOQuantToIntPass(false)); + AddQuantizationLoweringPasses(pm); CHECK(succeeded(pm.run(module_op.get()))); // Compile the program. 
return pjrt_client_->Compile(*module_op, xla::CompileOptions{}); @@ -392,5 +390,87 @@ func.func @main(%input: tensor<1x2xf32>, %filter: tensor<2x3xf32>) -> tensor<1x3 ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } +TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantize) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main(%input: tensor<4xf32>) -> tensor<4xf32> { + %input_scale = "tf.Const"() { value = dense<0.2235> : tensor } : () + -> tensor + %input_zp = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor + %output_scale = "tf.Const"() { value = dense<0.11> : tensor } : () + -> tensor + %output_zp = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + %0 = "tf.UniformQuantize"(%input, %input_scale, %input_zp) { + Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "", + quantization_axis = -1 : i64, quantization_max_val = 127 : i64, + quantization_min_val = -128 : i64 + } : (tensor<4xf32>, tensor, tensor) -> tensor<4x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = -1, input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64, + output_quantization_max_val = 127 : i64, output_quantization_min_val = -128 : i64 + } : ( + tensor<4x!tf_type.qint8>, tensor, tensor, tensor, tensor + ) -> tensor<4x!tf_type.qint8> + %2 = "tf.UniformDequantize"(%1, %output_scale, %output_zp) { + quantization_axis = -1 : i64, + quantization_min_val = -128 : i64, + quantization_max_val = 127 : i64 + } : (tensor<4x!tf_type.qint8>, tensor, tensor) -> tensor<4xf32> + return %2 : tensor<4xf32> +})mlir"; + auto input = xla::LiteralUtil::CreateR1({3.f, -10.f, -2.1f, 8.7f}); + ExecuteAndCompareResultsWithTfKernel(kProgram, {&input}); +} + +TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) { + constexpr absl::string_view kProgram = R"mlir( +func.func @main(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { + %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor } : () + -> tensor + %lhs_zp = "tf.Const"() { value = dense<42> : tensor } : () -> tensor + %rhs_scale = "tf.Const"() { value = dense<0.0239> : tensor } : () + -> tensor + %rhs_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %accum_scale = "tf.Const"() { value = dense<0.013> : tensor } : () + -> tensor + %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_lhs = "tf.UniformQuantize"(%lhs, %lhs_scale, %lhs_zp) { + Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT32", attr_map = "", + quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, + quantization_min_val = -2147483648 : i64 + } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint32> + %quant_rhs = "tf.UniformQuantize"(%rhs, %rhs_scale, %rhs_zp) { + Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT32", attr_map = "", + quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, + quantization_min_val = -2147483648 : i64 + } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint32> + %0 = "tf.UniformQuantizedAdd"( + %quant_lhs, %quant_rhs, %lhs_scale, %lhs_zp, %rhs_scale, + %rhs_zp, %accum_scale, %accum_zp + ) { + Tin = "tfdtype$DT_QINT32", Tout = "tfdtype$DT_QINT32", attr_map = "", + device = "", lhs_quantization_axis = -1 : i64, + lhs_quantization_max_val = 2147483647 : i64, lhs_quantization_min_val = 
-2147483648 : i64, + output_quantization_axis = -1 : i64, output_quantization_max_val = 2147483647 : i64, + output_quantization_min_val = -2147483648 : i64, rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 2147483647 : i64, rhs_quantization_min_val = -2147483648 : i64 + } : ( + tensor<2x2x!tf_type.qint32>, tensor<2x2x!tf_type.qint32>, tensor, + tensor, tensor, tensor, tensor, tensor + ) -> tensor<2x2x!tf_type.qint32> + %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, + quantization_max_val = 127 : i64 + } : (tensor<2x2x!tf_type.qint32>, tensor, tensor) -> tensor<2x2xf32> + return %output : tensor<2x2xf32> +})mlir"; + auto lhs = xla::LiteralUtil::CreateR2({{23.f, 12.f}, {-10.f, -3.f}}); + auto rhs = xla::LiteralUtil::CreateR2({{1.f, 2.f}, {-1.f, -3.f}}); + ExecuteAndCompareResultsWithTfKernel(kProgram, {&lhs, &rhs}); +} + } // namespace } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir index 9f9c11454c5c2a..36af198290dd2c 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/bridge/convert-mhlo-quant-to-int.mlir @@ -173,17 +173,11 @@ func.func @add_different_lhs_type( %arg0: tensor>, %arg1: tensor> ) -> tensor> { - // CHECK: %[[VAL1:.*]] = mhlo.convert %[[LHS:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor - // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor - // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor - // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor - // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor // CHECK-DAG: %[[RHS_32:.*]] = mhlo.convert %[[RHS:.*]] : (tensor) -> tensor // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor @@ -207,18 +201,11 @@ func.func @add_different_rhs_type( %arg0: tensor>, %arg1: tensor> ) -> tensor> { - // CHECK: %[[VAL0:.*]] = mhlo.convert %[[LHS:.*]] : (tensor) -> tensor - // CHECK: %[[VAL1:.*]] = mhlo.convert %[[RHS:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor - // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> 
: tensor - // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor - // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor - // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor // CHECK-DAG: %[[VAL7:.*]] = chlo.broadcast_add %[[LHS_32:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor @@ -239,29 +226,17 @@ func.func @add_different_res_type( %arg0: tensor>, %arg1: tensor> ) -> tensor> { - // CHECK: %[[VAL1:.*]] = mhlo.convert %[[LHS:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor - // CHECK: %[[VAL2:.*]] = chlo.broadcast_subtract %[[VAL1]], %[[INPUT_ZPS]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor - // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor - // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor - // CHECK: %[[VAL3:.*]] = chlo.broadcast_multiply %[[VAL2]], %[[MULTIPLIER]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL4:.*]] = chlo.broadcast_add %[[VAL3]], %[[HALF]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL5:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL4]], %[[TOTAL_SHIFT]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL5]], %[[OUTPUT_ZPS]] : (tensor, tensor) -> tensor - - // CHECK: %[[VAL6:.*]] = mhlo.convert %[[RHS:.*]] : (tensor) -> tensor - // CHECK-DAG: %[[INPUT_ZPS:.*]] = mhlo.constant dense<3> : tensor - // CHECK: %[[VAL7:.*]] = chlo.broadcast_subtract %[[VAL6]], %[[INPUT_ZPS]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[MULTIPLIER:.*]] = mhlo.constant dense<16384> : tensor - // CHECK-DAG: %[[TOTAL_SHIFT:.*]] = mhlo.constant dense<13> : tensor - // CHECK-DAG: %[[HALF:.*]] = mhlo.constant dense<4096> : tensor - // CHECK: %[[VAL8:.*]] = chlo.broadcast_multiply %[[VAL7]], %[[MULTIPLIER]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL9:.*]] = chlo.broadcast_add %[[VAL8]], %[[HALF]] : (tensor, tensor) -> tensor - // CHECK: %[[VAL10:.*]] = chlo.broadcast_shift_right_arithmetic %[[VAL9]], %[[TOTAL_SHIFT]] : (tensor, tensor) -> tensor - // CHECK-DAG: %[[OUTPUT_ZPS:.*]] = mhlo.constant dense<1> : tensor - // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[VAL10]], %[[OUTPUT_ZPS]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[LHS:.*]] = mhlo.convert %arg0 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[LHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> 
tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[LHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor + + // CHECK-DAG: %[[COMBINED_SCALE:.*]] = mhlo.constant dense<2.000000e+00> : tensor + // CHECK-DAG: %[[RHS:.*]] = mhlo.convert %arg1 : (tensor) -> tensor + // CHECK-DAG: %[[MUL:.*]] = chlo.broadcast_multiply %[[RHS]], %[[COMBINED_SCALE]] : (tensor, tensor) -> tensor + // CHECK-DAG: %[[COMBINED_ZP:.*]] = mhlo.constant dense<-5.000000e+00> + // CHECK: %[[RHS_32_REQ:.*]] = chlo.broadcast_add %[[MUL]], %[[COMBINED_ZP]] : (tensor, tensor) -> tensor // CHECK-DAG: %[[RES_ZPS:.*]] = mhlo.constant dense<1> : tensor // CHECK-DAG: %[[VAL11:.*]] = chlo.broadcast_add %[[LHS_32_REQ:.*]], %[[RHS_32_REQ:.*]] : (tensor, tensor) -> tensor From d734e0d9ee4462e8e51e1341e2db760d7aa9da29 Mon Sep 17 00:00:00 2001 From: Dateng Lin Date: Mon, 23 Oct 2023 10:51:00 -0700 Subject: [PATCH 016/246] Removed the unused `device` argument. PiperOrigin-RevId: 575866461 --- tensorflow/compiler/jit/get_compiler_ir.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 37987cb55ecf38..b11d6f178f7db3 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -148,7 +148,7 @@ static StatusOr> BuildXlaCompilerArgumentFromTensorSpec( const FunctionBody* fbody, absl::Span must_be_constant_idxs, absl::Span inputs, - absl::Span variable_args, Device* device, + absl::Span variable_args, absl::Span flat_arg_shape_and_dtype) { TF_RET_CHECK(fbody != nullptr); auto& input_args = fbody->fdef.signature().input_arg(); @@ -326,7 +326,7 @@ StatusOr GetCompilerIr( if (compiler_arg_source == CompilerArgSource::TENSOR_SPEC) { args = BuildXlaCompilerArgumentFromTensorSpec(fbody, constant_arg_indices, - inputs, variable_infos, dev, + inputs, variable_infos, input_arg_shape_and_dtype); } else if (compiler_arg_source == CompilerArgSource::CONCRETE_INPUT) { args = XlaComputationLaunchContext::BuildXlaCompilerArguments( From affb9df19ef6d3a7755cfc133a97098acc8107a1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 10:54:50 -0700 Subject: [PATCH 017/246] [XLA:Runtime] Moved the FFT thunk to a new folder and removed unused dependencies, as part of a thunk clean up, and updated the necessary directories pointing to this thunk. 
#5758 PiperOrigin-RevId: 575867727 --- third_party/xla/xla/service/gpu/BUILD | 26 ++----------------- .../xla/service/gpu/ir_emitter_unnested.cc | 2 +- third_party/xla/xla/service/gpu/runtime/BUILD | 2 +- .../xla/xla/service/gpu/runtime/fft.cc | 2 +- third_party/xla/xla/service/gpu/runtime/fft.h | 2 +- .../xla/xla/service/gpu/runtime3/BUILD | 22 ++++++++++++++++ .../service/gpu/{ => runtime3}/fft_thunk.cc | 2 +- .../service/gpu/{ => runtime3}/fft_thunk.h | 6 ++--- 8 files changed, 32 insertions(+), 32 deletions(-) rename third_party/xla/xla/service/gpu/{ => runtime3}/fft_thunk.cc (99%) rename third_party/xla/xla/service/gpu/{ => runtime3}/fft_thunk.h (95%) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index bc03ccb1365c18..8f97780283b215 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -266,7 +266,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":backend_configs_cc", - ":fft_thunk", ":gemm_thunk", ":gpu_asm_opts_util", ":gpu_constants", @@ -309,6 +308,7 @@ cc_library( "//xla/service/gpu/fusions:thunk_util", "//xla/service/gpu/fusions:tiling_util", "//xla/service/gpu/runtime3:custom_call_thunk", + "//xla/service/gpu/runtime3:fft_thunk", "//xla/service/llvm_ir:buffer_assignment_util", "//xla/service/llvm_ir:dynamic_update_slice_util", "//xla/service/llvm_ir:fused_ir_emitter", @@ -931,7 +931,6 @@ cc_library( ":backend_configs_cc", ":buffer_allocations", ":cusolver_context", - ":fft_thunk", ":gemm_thunk", ":gpu_asm_opts_util", ":gpu_constants", @@ -976,6 +975,7 @@ cc_library( "//xla/service/gpu/runtime:executable", "//xla/service/gpu/runtime:support", "//xla/service/gpu/runtime3:custom_call_thunk", + "//xla/service/gpu/runtime3:fft_thunk", "//xla/stream_executor", "//xla/stream_executor:blas", "//xla/stream_executor:device_description", @@ -1139,28 +1139,6 @@ build_cub_sort_kernels( ] + xla_cub_deps()), ) -cc_library( - name = "fft_thunk", - srcs = ["fft_thunk.cc"], - hdrs = ["fft_thunk.h"], - visibility = ["//visibility:public"], - deps = [ - ":buffer_allocations", - ":thunk", - "//xla:types", - "//xla:util", - "//xla:xla_data_proto_cc", - "//xla/hlo/ir:hlo", - "//xla/service:buffer_assignment", - "//xla/stream_executor", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/strings", - "@com_google_absl//absl/strings:str_format", - "@local_tsl//tsl/platform:logging", - "@local_tsl//tsl/platform:status", - ], -) - cc_library( name = "gemm_rewriter", srcs = ["gemm_rewriter.cc"], diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc index 80adeb85b62ca5..3421768142ec23 100644 --- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc +++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc @@ -88,7 +88,6 @@ limitations under the License. #include "xla/service/gpu/conditional_thunk.h" #include "xla/service/gpu/convolution_thunk.h" #include "xla/service/gpu/copy_thunk.h" -#include "xla/service/gpu/fft_thunk.h" #include "xla/service/gpu/for_thunk.h" #include "xla/service/gpu/fused_mha_thunk.h" #include "xla/service/gpu/fusions/fusions.h" @@ -116,6 +115,7 @@ limitations under the License. 
#include "xla/service/gpu/parallel_loop_emitter.h" #include "xla/service/gpu/replica_id_thunk.h" #include "xla/service/gpu/runtime3/custom_call_thunk.h" +#include "xla/service/gpu/runtime3/fft_thunk.h" #include "xla/service/gpu/sequential_thunk.h" #include "xla/service/gpu/thunk.h" #include "xla/service/gpu/while_thunk.h" diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD index 38b5845855a7d3..e991df2efb5279 100644 --- a/third_party/xla/xla/service/gpu/runtime/BUILD +++ b/third_party/xla/xla/service/gpu/runtime/BUILD @@ -298,7 +298,7 @@ cc_library( "//xla/runtime:custom_call_registry", "//xla/runtime:executable", "//xla/runtime:state", - "//xla/service/gpu:fft_thunk", + "//xla/service/gpu/runtime3:fft_thunk", "//xla/stream_executor:fft", "//xla/translate/mhlo_to_hlo:attribute_exporter", ], diff --git a/third_party/xla/xla/service/gpu/runtime/fft.cc b/third_party/xla/xla/service/gpu/runtime/fft.cc index fda8f01e5ed1bb..a668a8e1ff8476 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft.cc +++ b/third_party/xla/xla/service/gpu/runtime/fft.cc @@ -21,8 +21,8 @@ limitations under the License. #include "xla/runtime/custom_call.h" #include "xla/runtime/executable.h" #include "xla/runtime/state.h" -#include "xla/service/gpu/fft_thunk.h" #include "xla/service/gpu/runtime/support.h" +#include "xla/service/gpu/runtime3/fft_thunk.h" #include "xla/stream_executor/fft.h" namespace xla { diff --git a/third_party/xla/xla/service/gpu/runtime/fft.h b/third_party/xla/xla/service/gpu/runtime/fft.h index bc45e310272dc7..7a34e2d1db4e62 100644 --- a/third_party/xla/xla/service/gpu/runtime/fft.h +++ b/third_party/xla/xla/service/gpu/runtime/fft.h @@ -20,7 +20,7 @@ limitations under the License. #include "xla/mlir/runtime/transforms/custom_call_encoding.h" #include "xla/runtime/custom_call_registry.h" -#include "xla/service/gpu/fft_thunk.h" +#include "xla/service/gpu/runtime3/fft_thunk.h" namespace xla { namespace gpu { diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD index 0ff859f3f4b04c..e3ab3493aa7564 100644 --- a/third_party/xla/xla/service/gpu/runtime3/BUILD +++ b/third_party/xla/xla/service/gpu/runtime3/BUILD @@ -158,6 +158,28 @@ cc_library( ], ) +cc_library( + name = "fft_thunk", + srcs = ["fft_thunk.cc"], + hdrs = ["fft_thunk.h"], + visibility = ["//visibility:public"], + deps = [ + "//xla:types", + "//xla:util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/service:buffer_assignment", + "//xla/service/gpu:buffer_allocations", + "//xla/service/gpu:thunk", + "//xla/stream_executor", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:logging", + "@local_tsl//tsl/platform:status", + ], +) + cc_library( name = "triangular_solve_thunk", srcs = if_gpu_is_configured(["triangular_solve_thunk.cc"]), diff --git a/third_party/xla/xla/service/gpu/fft_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc similarity index 99% rename from third_party/xla/xla/service/gpu/fft_thunk.cc rename to third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc index 9198d43fdb7e8d..711ae990769c53 100644 --- a/third_party/xla/xla/service/gpu/fft_thunk.cc +++ b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "xla/service/gpu/fft_thunk.h" +#include "xla/service/gpu/runtime3/fft_thunk.h" #include diff --git a/third_party/xla/xla/service/gpu/fft_thunk.h b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h similarity index 95% rename from third_party/xla/xla/service/gpu/fft_thunk.h rename to third_party/xla/xla/service/gpu/runtime3/fft_thunk.h index 6b2224a7b13f01..4e0de39e1ad4ac 100644 --- a/third_party/xla/xla/service/gpu/fft_thunk.h +++ b/third_party/xla/xla/service/gpu/runtime3/fft_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef XLA_SERVICE_GPU_FFT_THUNK_H_ -#define XLA_SERVICE_GPU_FFT_THUNK_H_ +#ifndef XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ +#define XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ #include @@ -97,4 +97,4 @@ Status RunFft(se::DeviceMemoryBase input, const Shape& input_shape, } // namespace gpu } // namespace xla -#endif // XLA_SERVICE_GPU_FFT_THUNK_H_ +#endif // XLA_SERVICE_GPU_RUNTIME3_FFT_THUNK_H_ From 360b6e3e127156303a10d419470edab9615eea83 Mon Sep 17 00:00:00 2001 From: Shan Han Date: Mon, 23 Oct 2023 11:30:27 -0700 Subject: [PATCH 018/246] Make costs of each batch accessible to RPC handler. PiperOrigin-RevId: 575879128 --- tensorflow/core/common_runtime/request_cost.h | 2 ++ .../core/common_runtime/request_cost_test.cc | 25 ++++++++++--- tensorflow/core/kernels/batching_util/BUILD | 1 + .../batching_util/batch_resource_base.cc | 13 +++++-- .../batching_util/batch_resource_base_test.cc | 36 +++++++++++-------- 5 files changed, 55 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/common_runtime/request_cost.h b/tensorflow/core/common_runtime/request_cost.h index c67827014663eb..ce4e5cc92845c2 100644 --- a/tensorflow/core/common_runtime/request_cost.h +++ b/tensorflow/core/common_runtime/request_cost.h @@ -50,6 +50,8 @@ class RequestCost { int64_t input_size = 0; // In this batch, the padding amount. int64_t padding_size = 0; + // Costs for processing this batch. + absl::flat_hash_map batch_costs; }; // Records the metrics of a batch. diff --git a/tensorflow/core/common_runtime/request_cost_test.cc b/tensorflow/core/common_runtime/request_cost_test.cc index 17f8820c7cd3cf..052f1eef600f0b 100644 --- a/tensorflow/core/common_runtime/request_cost_test.cc +++ b/tensorflow/core/common_runtime/request_cost_test.cc @@ -22,6 +22,8 @@ limitations under the License. 
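The batch_costs field added above gives RPC handlers per-cost-type durations for every batch a request passed through. A sketch of a consumer (editor's illustration, not part of the patch; it assumes the map is keyed by cost-type string with absl::Duration values, matching the "gcu"/"tpu" entries exercised by the tests below, and that TotalBatchCost is a hypothetical helper):

#include <string>
#include "absl/time/time.h"
#include "tensorflow/core/common_runtime/request_cost.h"

// Sum one cost type (e.g. "tpu") across all batches of a request.
absl::Duration TotalBatchCost(const tensorflow::RequestCost& request_cost,
                              const std::string& cost_type) {
  absl::Duration total = absl::ZeroDuration();
  for (const auto& batch : request_cost.GetBatchMetrics()) {
    auto it = batch.batch_costs.find(cost_type);
    if (it != batch.batch_costs.end()) total += it->second;
  }
  return total;
}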
namespace tensorflow { namespace { +using ::testing::ElementsAre; +using ::testing::FieldsAre; using ::testing::Pair; using ::testing::UnorderedElementsAre; @@ -53,13 +55,26 @@ TEST(RequestCostTest, RecordBatchMetrics) { RequestCost request_cost; request_cost.RecordBatchMetrics(RequestCost::BatchMetrics{ - /*processed_size=*/8, /*input_size=*/8, /*padding_size=*/0}); + /*processed_size=*/8, + /*input_size=*/8, + /*padding_size=*/0, + {{"gcu", absl::Milliseconds(80)}, {"tpu", absl::Milliseconds(160)}}}); request_cost.RecordBatchMetrics(RequestCost::BatchMetrics{ - /*processed_size=*/4, /*input_size=*/2, /*padding_size=*/1}); + /*processed_size=*/4, + /*input_size=*/2, + /*padding_size=*/1, + {{"gcu", absl::Milliseconds(40)}, {"tpu", absl::Milliseconds(80)}}}); - EXPECT_THAT(request_cost.GetBatchMetrics(), - testing::ElementsAre(testing::FieldsAre(8, 8, 0), - testing::FieldsAre(4, 2, 1))); + EXPECT_THAT( + request_cost.GetBatchMetrics(), + ElementsAre( + FieldsAre(8, 8, 0, + UnorderedElementsAre(Pair("gcu", absl::Milliseconds(80)), + Pair("tpu", absl::Milliseconds(160)))), + FieldsAre( + 4, 2, 1, + UnorderedElementsAre(Pair("gcu", absl::Milliseconds(40)), + Pair("tpu", absl::Milliseconds(80)))))); } } // namespace diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index c1c7d39955fcde..3c49cf57ca4bef 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -366,6 +366,7 @@ cc_library( "//tensorflow/core/profiler/lib:traceme_encode", "//tensorflow/core/protobuf:for_core_protos_cc", "//tensorflow/core/util:incremental_barrier", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base.cc b/tensorflow/core/kernels/batching_util/batch_resource_base.cc index dcbc994df3b4b9..5724c5cdbc19b6 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base.cc @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "absl/container/flat_hash_map.h" #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/blocking_counter.h" @@ -1096,15 +1097,21 @@ void BatchResourceBase::SplitBatchCostsAndRecordMetrics( // 2. Records the batch metrics in each task. const int64_t padding_size = processed_size - batch.size(); + absl::flat_hash_map batch_costs; + for (const auto& batch_cost_measurement : batch_cost_measurements) { + if (batch_cost_measurement->GetTotalCost() > absl::ZeroDuration()) { + batch_costs[batch_cost_measurement->GetCostType()] = + batch_cost_measurement->GetTotalCost(); + } + } for (int i = 0; i < batch.num_tasks(); i++) { RequestCost* request_cost = batch.task(i).request_cost; // Skip recording the metrics if the request_cost is null. 
if (!request_cost) continue; request_cost->RecordBatchMetrics(RequestCost::BatchMetrics{ - /*processed_size=*/processed_size, - /*input_size=*/static_cast(batch.task(i).size()), - /*padding_size=*/padding_size}); + processed_size, static_cast(batch.task(i).size()), + padding_size, batch_costs}); } } diff --git a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc index b7d4d6d3cb1dd7..18385670d2dabb 100644 --- a/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc +++ b/tensorflow/core/kernels/batching_util/batch_resource_base_test.cc @@ -75,10 +75,10 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnNoCostMeasurement) { /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); - EXPECT_THAT( - batch.task(0).request_cost->GetBatchMetrics(), - ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15))); + EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), + ::testing::ElementsAre(::testing::FieldsAre( + /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15, + ::testing::IsEmpty()))); } TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroCost) { @@ -95,10 +95,10 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroCost) { /*processed_size=*/16, batch); EXPECT_TRUE(batch.task(0).request_cost->GetCosts().empty()); - EXPECT_THAT( - batch.task(0).request_cost->GetBatchMetrics(), - ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15))); + EXPECT_THAT(batch.task(0).request_cost->GetBatchMetrics(), + ::testing::ElementsAre(::testing::FieldsAre( + /*processed_size=*/16, /*input_size=*/1, /*padding_size=*/15, + ::testing::IsEmpty()))); } TEST(SplitBatchCostsAndRecordMetricsTest, SkipOnZeroBatchSize) { @@ -154,7 +154,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitSingleCostType) { EXPECT_THAT( batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)))))); EXPECT_THAT( batch.task(1).request_cost->GetCosts(), UnorderedElementsAre(Pair("test_tpu_with_smear", absl::Milliseconds(90)), @@ -162,7 +163,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitSingleCostType) { EXPECT_THAT( batch.task(1).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)))))); } TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) { @@ -191,7 +193,9 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) { EXPECT_THAT( batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)), + Pair("test_gcu", absl::Milliseconds(200)))))); EXPECT_THAT( batch.task(1).request_cost->GetCosts(), @@ -202,7 +206,9 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitMultiCostTypes) { EXPECT_THAT( batch.task(1).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, 
/*input_size=*/9, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)), + Pair("test_gcu", absl::Milliseconds(200)))))); } TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { @@ -229,7 +235,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { EXPECT_THAT( batch.task(0).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/1, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)))))); EXPECT_THAT( batch.task(1).request_cost->GetCosts(), @@ -238,7 +245,8 @@ TEST(SplitBatchCostsAndRecordMetricsTest, SplitOnlyNonZeroCostTypes) { EXPECT_THAT( batch.task(1).request_cost->GetBatchMetrics(), ::testing::ElementsAre(::testing::FieldsAre( - /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10))); + /*processed_size=*/20, /*input_size=*/9, /*padding_size=*/10, + UnorderedElementsAre(Pair("test_tpu", absl::Milliseconds(100)))))); } } // namespace From a9e19d0de09d81e1a46d32f86ec04ab5203d7bf0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 11:37:41 -0700 Subject: [PATCH 019/246] Add a 3.12 Docker container. PiperOrigin-RevId: 575881215 --- .github/workflows/sigbuild-docker-branch.yml | 2 +- .../workflows/sigbuild-docker-presubmit.yml | 3 ++- .github/workflows/sigbuild-docker.yml | 2 +- .../ci_build/release/requirements_common.txt | 20 +++++++------- .../builder.devtoolset/build_devtoolset.sh | 2 +- .../devel.requirements.txt | 27 ++++++++++--------- .../tf_sig_build_dockerfiles/setup.python.sh | 1 + 7 files changed, 32 insertions(+), 25 deletions(-) diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml index 108fe471efa2db..9f842f9fb27c11 100644 --- a/.github/workflows/sigbuild-docker-branch.yml +++ b/.github/workflows/sigbuild-docker-branch.yml @@ -34,7 +34,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.9, python3.10, python3.11] + python-version: [python3.9, python3.10, python3.11, python3.12] steps: - name: Delete unnecessary tools folder run: rm -rf /opt/hostedtoolcache diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml index c61e65e7d834c0..03ae6f1dadf63f 100644 --- a/.github/workflows/sigbuild-docker-presubmit.yml +++ b/.github/workflows/sigbuild-docker-presubmit.yml @@ -32,7 +32,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.9, python3.10, python3.11] + python-version: [python3.9, python3.10, python3.11, python3.12] permissions: contents: read pull-requests: write @@ -87,6 +87,7 @@ jobs: message: | I pushed these containers: + - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.12` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.11` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.10` - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.9` diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml index ce9b99c494fc5e..5549f2995ac80f 100644 --- a/.github/workflows/sigbuild-docker.yml +++ b/.github/workflows/sigbuild-docker.yml @@ -37,7 +37,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [python3.9, python3.10, python3.11] + python-version: [python3.9, python3.10, 
python3.11, python3.12] steps: - name: Delete unnecessary tools folder run: rm -rf /opt/hostedtoolcache diff --git a/tensorflow/tools/ci_build/release/requirements_common.txt b/tensorflow/tools/ci_build/release/requirements_common.txt index e9a79b120b5672..da3ea7d765f113 100644 --- a/tensorflow/tools/ci_build/release/requirements_common.txt +++ b/tensorflow/tools/ci_build/release/requirements_common.txt @@ -5,18 +5,19 @@ absl-py ~= 1.0.0 astunparse ~= 1.6.3 flatbuffers ~= 23.5.26 google_pasta ~= 0.2 -h5py ~= 3.8.0 # Earliest version for Python 3.11 -ml_dtypes ~= 0.2 +h5py ~= 3.10.0 # Earliest version for Python 3.12 +ml_dtypes ~= 0.3.1 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower # to support TFX. Remove when Apache Beam upgrades to newer NumPy. numpy ~= 1.22.0; python_version < '3.11' -numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11 +numpy ~= 1.23.2; python_version == '3.11' # Earliest version for Python 3.11 +numpy ~= 1.26.0; python_version >= '3.12' # Earliest version for Python 3.12 opt_einsum ~= 3.3.0 protobuf ~= 3.20.3 # NOTE: Earliest version for Python 3.10 six ~= 1.16.0 termcolor ~= 2.1.1 -typing_extensions ~= 3.10.0.0 -wheel ~= 0.38.1 +typing_extensions ~= 4.8.0 +wheel ~= 0.41.2 wrapt ~= 1.14.1 # We need to pin the gast dependency exactly @@ -31,14 +32,15 @@ tb-nightly ~= 2.14.0.a tf-estimator-nightly ~= 2.14.0.dev # Test dependencies -grpcio ~= 1.49.1 # Earliest version for Python 3.11 -portpicker ~= 1.5.2 +grpcio ~= 1.59.0 # Earliest version for Python 3.12 +portpicker ~= 1.6.0 scipy ~= 1.7.2; python_version < '3.11' -scipy ~= 1.9.2; python_version >= '3.11' # Earliest version for Python 3.11 +scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11 +scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12 # This is usually vendored in setuptools but ensure it gets installed in CI anyway # No bound here, we prefer the one in setuptools packaging # For using Python 3.11 with Bazel 6 (b/286090018) -lit ~= 16.0.5.post0 +lit ~= 17.0.2 diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh b/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh index 6f4a566fcac306..b4c63677d7ae76 100755 --- a/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh +++ b/tensorflow/tools/tf_sig_build_dockerfiles/builder.devtoolset/build_devtoolset.sh @@ -184,7 +184,7 @@ esac # TODO(klimek): Automate linking in all non-gcc / non-kernel include # directories. mkdir -p "/${TARGET}/usr/include/x86_64-linux-gnu" -PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11") +PYTHON_VERSIONS=("python3.9" "python3.10" "python3.11" "python3.12") for v in "${PYTHON_VERSIONS[@]}"; do ln -s "/usr/local/include/${v}" "/${TARGET}/usr/include/x86_64-linux-gnu/${v}" done diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt index 9e12049cc6b4dc..62e73c996b1829 100644 --- a/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt +++ b/tensorflow/tools/tf_sig_build_dockerfiles/devel.requirements.txt @@ -8,19 +8,21 @@ absl-py ~= 1.0.0 astunparse ~= 1.6.3 flatbuffers ~= 23.5.26 google_pasta ~= 0.2 -h5py ~= 3.8.0 # Earliest version for Python 3.11 -ml_dtypes ~= 0.2 +h5py ~= 3.10.0 # Earliest version for Python 3.12 +ml_dtypes ~= 0.3.1 # TODO(b/262592253): Support older versions of NumPy for Python 3.10 and lower # to support TFX. 
Remove when Apache Beam upgrades to newer NumPy. numpy ~= 1.22.0; python_version < '3.11' -numpy ~= 1.23.2; python_version >= '3.11' # Earliest version for Python 3.11 +numpy ~= 1.23.2; python_version == '3.11' # Earliest version for Python 3.11 +numpy ~= 1.26.0; python_version >= '3.12' # Earliest version for Python 3.12 opt_einsum ~= 3.3.0 -packaging ~= 21.3 +packaging ~= 23.2 protobuf ~= 3.20.3 six ~= 1.16.0 termcolor ~= 2.1.1 -typing_extensions ~= 3.10.0.0 -wheel ~= 0.38.1 +typing_extensions ~= 4.8.0 +wheel ~= 0.41.2 +setuptools >= 68.2.2 wrapt ~= 1.14.1 # We need to pin the gast dependency exactly gast == 0.4.0 @@ -34,13 +36,14 @@ keras-nightly ~= 2.14.0.dev tb-nightly ~= 2.13.0.a tf-estimator-nightly ~= 2.14.0.dev # Test dependencies -grpcio ~= 1.49.1 # Earliest version for Python 3.11 -portpicker ~= 1.5.2 +grpcio ~= 1.59.0 # Earliest version for Python 3.12 +portpicker ~= 1.6.0 scipy ~= 1.7.2; python_version < '3.11' -scipy ~= 1.9.2; python_version >= '3.11' # Earliest version for Python 3.11 +scipy ~= 1.9.2; python_version == '3.11' # Earliest version for Python 3.11 +scipy ~= 1.11.3; python_version >= '3.12' # Earliest version for Python 3.12 # Required for TFLite import from JAX tests -jax ~= 0.3.25 -jaxlib ~= 0.3.25 # Earliest version for Python 3.11 +jax ~= 0.3.25; python_version <= '3.11' +jaxlib ~= 0.3.25; python_version <= '3.11' # Earliest version for Python 3.11 # Needs to be addressed. Unblocked 2.4 branchcut cl/338377048 PyYAML ~= 6.0 # For uploading @@ -52,4 +55,4 @@ lxml ~= 4.9.1 pylint ~= 2.13.9 # For using Python 3.11 with Bazel 6 (b/286090018) -lit ~= 16.0.5.post0 +lit ~= 17.0.2 diff --git a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh index 98d23d03eb0006..007ee34858ed2e 100755 --- a/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh +++ b/tensorflow/tools/tf_sig_build_dockerfiles/setup.python.sh @@ -59,6 +59,7 @@ fi curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py python3 get-pip.py python3 -m pip install --no-cache-dir --upgrade pip +python3 -m pip install -U setuptools # Disable the cache dir to save image space, and install packages python3 -m pip install --no-cache-dir -r $REQUIREMENTS -U From f2a52bd533aafa688a9e9582da083187d80caa00 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 11:58:28 -0700 Subject: [PATCH 020/246] Logs a warning if no default sharding strategy can be found for a given instruction, and reports the total # of such nodes. 
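A minimal sketch of the detection logic this change adds, with simplified
types (`penalties` stands in for the per-node strategy penalty vectors; the
names here are illustrative, not the actual XLA API):

    #include <algorithm>
    #include <vector>

    // A node has a usable default only if some candidate strategy matched
    // the sharding-propagation solution (penalty 0 marks a match).
    int CountNodesWithoutDefault(
        const std::vector<std::vector<double>>& penalties) {
      int num_without_default = 0;
      for (const std::vector<double>& pi : penalties) {
        if (!pi.empty() && *std::min_element(pi.begin(), pi.end()) > 0) {
          ++num_without_default;  // the change also logs a warning here
        }
      }
      return num_without_default;
    }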
PiperOrigin-RevId: 575887322
---
 .../experimental/auto_sharding/auto_sharding.cc  | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
index d19e9205d53cbd..5be2d22e7c7516 100644
--- a/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
+++ b/third_party/xla/xla/hlo/experimental/auto_sharding/auto_sharding.cc
@@ -2521,19 +2521,20 @@ AutoShardingSolverResult CallSolver(
   const std::vector<HloInstruction*>& instructions = sequence.instructions();

   // Serialize node costs
+  int num_nodes_without_default = 0;
   for (NodeIdx node_idx = 0; node_idx < request.num_nodes; ++node_idx) {
     const StrategyVector* strategies = leaf_strategies[node_idx];
     auto instruction_name = instructions.at(strategies->instruction_id)->name();
     request.instruction_names.push_back(
         absl::StrCat(instruction_name, " (id: ", node_idx, ")"));
     std::vector<double> ci, di, mi, pi;
-    auto default_strategy = HloSharding::Replicate();
+    std::optional<HloSharding> default_strategy;
     auto iter = sharding_propagation_solution.find(instruction_name);
     if (iter != sharding_propagation_solution.end()) {
       CHECK(iter->second->has_sharding()) << iter->second->ToString();
       default_strategy = iter->second->sharding();
       if (strategies->tuple_element_idx) {
-        const auto& tuple_elements = default_strategy.tuple_elements();
+        const auto& tuple_elements = default_strategy->tuple_elements();
         CHECK_LT(*strategies->tuple_element_idx, tuple_elements.size());
         default_strategy = tuple_elements.at(*strategies->tuple_element_idx);
       }
@@ -2545,15 +2546,20 @@ AutoShardingSolverResult CallSolver(
       di.push_back(strategy.communication_cost +
                    cost_graph.extra_node_costs_[node_idx][j]);
       mi.push_back(strategy.memory_cost);
-      // TODO(moffitt): Revisit the default strategy below, which is currently
-      // defined as the "trivial sharding" in hlo_sharding.h
-      pi.push_back(sharding == default_strategy ? 0.0 : 1.0);
+      pi.push_back(default_strategy && sharding == *default_strategy ? 0 : 1);
+    }
+    if (*std::min_element(pi.begin(), pi.end()) > 0) {
+      LOG(WARNING) << "No default strategy for {node_idx " << node_idx
+                   << ", instruction ID " << strategies->instruction_id
+                   << ", instruction name " << instruction_name << "}";
+      ++num_nodes_without_default;
     }
     request.c.push_back(ci);
     request.d.push_back(di);
     request.m.push_back(mi);
     request.p.push_back(pi);
   }
+  LOG(INFO) << "Total nodes without default: " << num_nodes_without_default;

   // Serialize special edges that forces a alias pair have the same sharding
   // spec

From d5d0deec1c29292fe3d1a23a8bd0382d20cfe0c5 Mon Sep 17 00:00:00 2001
From: Jim Lin
Date: Mon, 23 Oct 2023 12:04:12 -0700
Subject: [PATCH 021/246] #tf-data destroys `element` before the `mutex_lock
 l(*mu)` is destroyed. This guarantees that the iterators created inside the
 parallel interleave iterator are destroyed before the parallel interleave
 iterator itself.
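A minimal sketch of the ordering idiom, assuming a simplified `Element` type
(illustrative only, not the actual tf.data classes):

    #include <memory>

    struct Element { /* owns a nested iterator */ };

    void WorkerThread() {
      // Declared before the cleanup lambda so that resetting it inside the
      // lambda destroys the nested iterator while this worker still runs,
      // before any waiter observes the worker as finished.
      std::shared_ptr<Element> element;
      auto done = [&]() {
        element.reset();  // nested iterator destroyed here...
        // ...then the bookkeeping that can wake CancelThreads() runs.
      };
      // Per-iteration work assigns `element`, then invokes done() on exit.
    }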
Otherwise, `CancelThreads(true)` might finish before `element`'s destructor
is called because of `outstanding_threads_finished_cond_var_.wait(l);`

PiperOrigin-RevId: 575888988
---
 .../kernels/data/parallel_interleave_dataset_op.cc  | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index aa24c0122ad463..2ae28298d81f27 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -931,7 +931,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
     void CurrentWorkerThread(std::shared_ptr<IteratorContext> ctx)
         TF_LOCKS_EXCLUDED(mu_) {
       RecordStart(ctx.get());
+      std::shared_ptr<Element> element;
       auto done = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        // Release the shared ownership so that
+        // the iterator managed by `element` is guaranteed destroyed
+        // before this class instance members.
+        element.reset();
         RecordStop(ctx.get());
         DecrementActiveWorkers();
         DecrementCurrentActiveWorkers();
@@ -940,7 +945,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       };
       while (true) {
         int element_index;
-        std::shared_ptr<Element> element;
+        element.reset();
         // Find an element to process.
         {
           mutex_lock l(*mu_);
@@ -1000,12 +1005,16 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
     void FutureWorkerThread(std::shared_ptr<IteratorContext> ctx)
         TF_LOCKS_EXCLUDED(mu_) {
       RecordStart(ctx.get());
+      std::shared_ptr<Element> element;
       auto done = [&]() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+        // Release the shared ownership so that
+        // the iterator managed by `element` is guaranteed destroyed
+        // before this class instance members.
+        element.reset();
         RecordStop(ctx.get());
         DecrementActiveWorkers();
         DecrementOutstandingThreads();
       };
-      std::shared_ptr<Element> element;
       while (true) {
         {
           mutex_lock l(*mu_);

From 4fd71da7305f7d461af2ad19a07e0271e0b9d542 Mon Sep 17 00:00:00 2001
From: Chris Jones
Date: Mon, 23 Oct 2023 12:08:05 -0700
Subject: [PATCH 022/246] [stream_executor] Use `StreamIsCapturing` function.

PiperOrigin-RevId: 575890152
---
 .../xla/stream_executor/cuda/cuda_driver.cc  | 19 ++++++-------------
 1 file changed, 6 insertions(+), 13 deletions(-)

diff --git a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
index 1c0e48a93441fd..bf16b27f56ca73 100644
--- a/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
+++ b/third_party/xla/xla/stream_executor/cuda/cuda_driver.cc
@@ -37,7 +37,6 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
-#include "third_party/gpus/cuda/include/driver_types.h"
 #include "xla/stream_executor/cuda/cuda_diagnostics.h"
 #include "xla/stream_executor/gpu/gpu_driver.h"
 #include "xla/stream_executor/gpu/gpu_types.h"
@@ -1591,21 +1590,15 @@ GpuDriver::CreateMemoryHandle(GpuContext* context, uint64_t bytes) {
   ScopedActivateContext activation(context);
   CUresult result;

-  // Check if the stream is doing graph capture.
- cudaStreamCaptureStatus stream_capture_status; - cudaError_t err = - cudaStreamGetCaptureInfo(stream, &stream_capture_status, /*pId=*/nullptr); - if (err != cudaSuccess) { - LOG(ERROR) << "Failed to get stream capture info: " - << cudaGetErrorString(err); - return false; - } - // In graph capture mode we never have operations that access peer memory, so // we can always make a call to cuMemcpyDtoDAsync. - bool is_capturing = stream_capture_status == cudaStreamCaptureStatusActive; + tsl::StatusOr is_capturing = StreamIsCapturing(stream); + if (!is_capturing.ok()) { + LOG(ERROR) << is_capturing.status().message(); + return false; + } - if ((gpu_dst == 0 || gpu_src == 0) || is_capturing) { + if ((gpu_dst == 0 || gpu_src == 0) || (*is_capturing)) { // CreatedContexts::GetAnyContext() doesn't works when ptr == 0. // This happens when the size is 0. result = cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream); From 9e31959b058e19a7f63f5dabcec1418a80c6cc74 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 23 Oct 2023 12:46:15 -0700 Subject: [PATCH 023/246] Remove dependency from "GetMinibatchSplitsWithPhysicalReplica" to a flag. The flag could have a different value in different processes, which would invalidate the inferred shape. The dependency also inflates the size of the op library, by adding a dependency on a kernel library (and transitively on MLIR), and pulls in unrelated ops which complicates wrapper generation. PiperOrigin-RevId: 575900618 --- tensorflow/core/tpu/ops/BUILD | 1 - .../core/tpu/ops/sparse_core_preprocess_ops.cc | 18 ++---------------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD index 5bb259374fa0b3..e49c75d54d269d 100644 --- a/tensorflow/core/tpu/ops/BUILD +++ b/tensorflow/core/tpu/ops/BUILD @@ -186,7 +186,6 @@ cc_library( deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", - "//tensorflow/core/tpu/kernels:sparse_core_ops_utils", "@local_xla//xla:util", ], alwayslink = 1, diff --git a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc index d25c9d1489ae3e..a6d1fa90dca82d 100644 --- a/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc +++ b/tensorflow/core/tpu/ops/sparse_core_preprocess_ops.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/tpu/kernels/sparse_core_ops_utils.h" #include "tsl/platform/errors.h" namespace tensorflow { @@ -131,22 +130,9 @@ REGISTER_OP("GetMinibatchSplitsWithPhysicalReplica") c->set_output(0, c->UnknownShapeOfRank(1)); c->set_output(1, c->UnknownShapeOfRank(1)); c->set_output(2, c->UnknownShapeOfRank(1)); - int32 num_replica; - TF_RETURN_IF_ERROR(c->GetAttr("num_replica", &num_replica)); - - int32 num_sc_per_chip; - TF_RETURN_IF_ERROR(c->GetAttr("num_sc_per_chip", &num_sc_per_chip)); - - const int max_division_level = GetMinibatchMaxDivisionLevel(); - - const int num_physical_replica = num_replica * num_sc_per_chip; - - const int32 kMaxDivisions = 1 << max_division_level; - c->set_output(3, c->Scalar()); - c->set_output( - 4, c->MakeShape( - {num_physical_replica * kMaxDivisions * num_sc_per_chip + 1})); + // Depends on max division level, which is currently passed by flag. 
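  // Editorial sketch of the hazard (flag name and values hypothetical): if
  // shape inference ran with a max division level of 5 in one process while
  // the kernel ran with 6 in another, the two would compute different sizes:
  //   inferred:  num_physical_replica * (1 << 5) * num_sc_per_chip + 1
  //   executed:  num_physical_replica * (1 << 6) * num_sc_per_chip + 1
  // Hence the unknown-rank output below.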
+ c->set_output(4, c->UnknownShapeOfRank(1)); c->set_output(5, c->Scalar()); c->set_output(6, c->Scalar()); return OkStatus(); From 1c22230666d36c520ee6c41fbb8be8f5e90097e3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 13:02:27 -0700 Subject: [PATCH 024/246] Fixes typo in function args comment PiperOrigin-RevId: 575904949 --- tensorflow/python/ops/variables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index d99b7e114aa372..5208dd1c8229ae 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -352,7 +352,7 @@ def __init__(self, variable and return the Tensor for the projected value (which must have the same shape). Constraints are not safe to use when doing asynchronous distributed training. - synchronization: Indicates when a distributed a variable will be + synchronization: Indicates when a distributed variable will be aggregated. Accepted values are constants defined in the class `tf.VariableSynchronization`. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses when to From 10ac33ac80c5355f72fa0916cf87a82031d96599 Mon Sep 17 00:00:00 2001 From: Mason Chang Date: Mon, 23 Oct 2023 13:19:56 -0700 Subject: [PATCH 025/246] Move error logging into lower cluster to runtime ops PiperOrigin-RevId: 575910019 --- .../tensorflow/transforms/host_runtime/BUILD | 7 +++ .../lower_cluster_to_runtime_ops.cc | 46 ++++++++++++++++++- .../lower_cluster_to_runtime_ops_test.cc | 12 +++++ .../compiler/mlir/tf2xla/api/v1/cluster_tf.cc | 7 --- .../compiler/mlir/tf2xla/api/v2/cluster_tf.cc | 7 +-- 5 files changed, 66 insertions(+), 13 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index fcf2b1bca58552..359dd5c4624712 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -29,14 +29,20 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", "//tensorflow/core:framework", "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core/platform:error_payloads", + "//tensorflow/core/platform:status", "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", + "@local_tsl//tsl/platform:error_logging", + "@local_tsl//tsl/platform:errors", ], ) @@ -56,6 +62,7 @@ tf_cc_test( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:test", + "//tensorflow/core/lib/monitoring:cell_reader", "//tensorflow/core/platform:resource_loader", "//tensorflow/core/tpu:tpu_defs", "@com_google_absl//absl/status", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc index 4dfe7c1fc19179..cba2b05cc2e78a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "absl/log/log.h" +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -29,10 +30,15 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/data_dumper_logger_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/framework/metrics.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/error_payloads.h" +#include "tensorflow/core/platform/status.h" #include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/util/debug_data_dumper.h" #include "tsl/framework/device_type.h" +#include "tsl/platform/error_logging.h" +#include "tsl/platform/errors.h" namespace tensorflow { namespace tfrt_compiler { @@ -111,6 +117,39 @@ void CreateNonTPULowerClusterToRuntimeOpsPassPipeline( AddNonTPULowerClusterToRuntimeOpsPassPipeline(pm, /*module_name=*/""); } +// TODO(b/306728216): Move this out of the Bridge component and into a Host +// runtime component. +tensorflow::Status RecordIfErrorStatus(const std::string error_prefix, + tsl::DeviceType device_type, + absl::Status status) { + if (status.ok()) { + return status; + } + + VLOG(2) << error_prefix << " " << status; + tensorflow::metrics::UpdateTfMlirBridgeFirstPhaseCounter( + device_type.type_string(), /*bridge_version=*/"v2", + /*fallback_enabled=*/false, + /*result=*/"failure"); + + constexpr char kBridgeComponent[] = "TFXLABridge"; + std::string bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_TPU_BRIDGE"; + + tsl::OkOrSetErrorCounterPayload( + tensorflow::core::platform::ErrorSourceProto::MLIR_BRIDGE_PHASE_1, + status); + + if (device_type != DeviceType(DEVICE_TPU_XLA_JIT)) { + bridge_subcomponent = "TFXLA_PHASE_ONE_MLIR_CPU/GPU_BRIDGE"; + } + + tsl::error_logging::Log(kBridgeComponent, bridge_subcomponent, + status.ToString()) + .IgnoreError(); + + return status; +} + absl::Status RunLowerClusterToRuntimeOpsPassPipeline( mlir::ModuleOp module, tsl::DeviceType xla_device_type, llvm::StringRef module_name) { @@ -154,7 +193,12 @@ absl::Status RunLowerClusterToRuntimeOpsPassPipeline( module, llvm::StringRef(), &runtime_lowering); } - return diag_handler.ConsumeStatus(); + auto result_status = diag_handler.ConsumeStatus(); + TF_RETURN_IF_ERROR( + RecordIfErrorStatus(/*error_prefix=*/"lower_cluster_to_runtime", + xla_device_type, result_status)); + + return absl::OkStatus(); } // TODO(b/305211853): Unify the CPU/TPU/GPU Execution Ops and thus these two diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc index b58a13d679adb6..e0e376e200e992 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h" +#include #include #include @@ -33,6 +34,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/register_common_dialects.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/lib/monitoring/cell_reader.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/resource_loader.h" #include "tensorflow/core/platform/test.h" @@ -51,6 +53,7 @@ using mlir::ModuleOp; using mlir::OpPassManager; using mlir::OwningOpRef; using mlir::func::FuncOp; +using ::tensorflow::monitoring::testing::CellReader; using tsl::DeviceType; std::string TestDataPath() { @@ -58,6 +61,9 @@ std::string TestDataPath() { "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/testdata/"); } +static constexpr char kCompilationStreamz[] = + "/tensorflow/core/tf_mlir_bridge_first_phase_count"; + class LowerClusterToRuntimeOpsTest : public ::testing::Test { public: LowerClusterToRuntimeOpsTest() { @@ -154,11 +160,17 @@ TEST_F(LowerClusterToRuntimeOpsTest, LowersClusterOpsGPU) { } TEST_F(LowerClusterToRuntimeOpsTest, ErrorsWithBadCluster) { + CellReader compilation_status(kCompilationStreamz); + TF_ASSERT_OK(CreateMlirModule("malformed_cluster.mlir")); EXPECT_FALSE(RunLowerClusterToRuntimeOpsPassPipeline( *mlir_module_, DeviceType(DEVICE_TPU_XLA_JIT)) .ok()); + + EXPECT_EQ(compilation_status.Delta("XLA_TPU_JIT", "v2", "fallback_disabled", + "failure"), + 1); } TEST_F(LowerClusterToRuntimeOpsTest, DumpsPipelinePasses) { diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc index 105b9016bb3c24..91e6c179ee8aae 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/cluster_tf.cc @@ -227,13 +227,6 @@ absl::Status RunLowerToRuntimeOpsOnSubmodule(ModuleOp parent_module, is_in_fallback_enabled_mode, num_submodules_error)); } - if (!runtime_lowering_status.ok()) { - TF_RETURN_IF_ERROR(RecordStatusIfError( - /*error_prefix=*/ - "Errored running lowering cluster ops to runtime ops pipeline:", - is_in_fallback_enabled_mode, runtime_lowering_status)); - } - return absl::OkStatus(); } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc index b07d0340705e01..c9071c9f51b072 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v2/cluster_tf.cc @@ -200,12 +200,9 @@ tensorflow::Status RunNonTPUBridge(ModuleOp module, is_in_fallback_enabled_mode, device_type, clustering_status)); - Status runtime_lowering_status = + TF_RETURN_IF_ERROR( tensorflow::tfrt_compiler::RunLowerClusterToRuntimeOpsPassPipeline( - module, tsl::DeviceType("XLA_GPU_JIT"), module_name); - TF_RETURN_IF_ERROR(RecordIfErrorStatus(/*error_prefix=*/"runtime_lowering_v2", - is_in_fallback_enabled_mode, - device_type, runtime_lowering_status)); + module, tsl::DeviceType("XLA_GPU_JIT"), module_name)); Status export_status = tensorflow::tf2xla::v2::ExportFromTensorflowDialectToExecutor( From 0ed832a9ea2e3213539169322769430c85703141 Mon Sep 17 00:00:00 2001 From: David Dunleavy Date: Mon, 23 Oct 2023 13:26:40 -0700 Subject: [PATCH 026/246] Refactor xla.bzl to not repeat deps list in xla_cc_{binary,test} In preparation to create `xla_protos_all` to simplify further PiperOrigin-RevId: 575911998 --- third_party/xla/xla/xla.bzl | 123 +++++++++++++----------------------- 1 file changed, 45 insertions(+), 78 deletions(-) diff --git a/third_party/xla/xla/xla.bzl b/third_party/xla/xla/xla.bzl 
index 00939bee30df77..d091406bcff59c 100644 --- a/third_party/xla/xla/xla.bzl +++ b/third_party/xla/xla/xla.bzl @@ -6,7 +6,7 @@ load( ) load( "@local_tsl//tsl:tsl.bzl", - "if_tsl_link_protobuf", + "if_oss", "tsl_copts", _tsl_clean_dep = "clean_dep", ) @@ -39,87 +39,54 @@ ORC_JIT_MEMORY_MAPPER_TARGETS = [] def xla_py_test_deps(): return [] -def xla_cc_binary(deps = None, copts = tsl_copts(), **kwargs): - if not deps: - deps = [] +# TODO(ddunleavy): some of these should be removed from here and added to +# specific targets. +# We actually shouldn't need this anymore post vendoring. If we build without +# `framework_shared_object` in the bazelrc all of this should be able to go +# away. The problem is making sure that all these impl deps are `if_static`'d +# appropriately throughout XLA. +_XLA_SHARED_OBJECT_SENSITIVE_DEPS = if_oss([_tsl_clean_dep("@com_google_protobuf//:protobuf")]) + [ + clean_dep("//xla:xla_proto_cc_impl"), + clean_dep("//xla:xla_data_proto_cc_impl"), + clean_dep("//xla/service:hlo_proto_cc_impl"), + clean_dep("//xla/service:buffer_assignment_proto_cc_impl"), + clean_dep("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"), + clean_dep("//xla/service/gpu:backend_configs_cc_impl"), + clean_dep("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"), + clean_dep("//xla/stream_executor:device_description_proto_cc_impl"), + clean_dep("//xla/stream_executor:device_id_utils"), + clean_dep("//xla/stream_executor:stream_executor_impl"), + clean_dep("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"), + clean_dep("//xla/stream_executor/gpu:gpu_init_impl"), + clean_dep("@local_tsl//tsl/profiler/utils:time_utils_impl"), + clean_dep("@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl"), + clean_dep("@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl"), + clean_dep("@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), + clean_dep("@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl"), + clean_dep("//xla:autotune_results_proto_cc_impl"), + clean_dep("//xla:autotuning_proto_cc_impl"), + clean_dep("@local_tsl//tsl/protobuf:protos_all_cc_impl"), + clean_dep("@local_tsl//tsl/platform:env_impl"), + clean_dep("@local_tsl//tsl/framework:allocator"), + clean_dep("@local_tsl//tsl/framework:allocator_registry_impl"), + clean_dep("@local_tsl//tsl/util:determinism"), +] + if_cuda_is_configured([ + clean_dep("//xla/stream_executor/cuda:cuda_stream"), + clean_dep("//xla/stream_executor/cuda:all_runtime"), + clean_dep("//xla/stream_executor/cuda:stream_executor_cuda"), +]) + if_rocm_is_configured([ + clean_dep("//xla/stream_executor/gpu:gpu_stream"), + clean_dep("//xla/stream_executor/rocm:all_runtime"), + clean_dep("//xla/stream_executor/rocm:stream_executor_rocm"), +]) - # TODO(ddunleavy): some of these should be removed from here and added to - # specific targets. 
- deps += [ - _tsl_clean_dep("@com_google_protobuf//:protobuf"), - "//xla:xla_proto_cc_impl", - "//xla:xla_data_proto_cc_impl", - "//xla/service:hlo_proto_cc_impl", - "//xla/service:buffer_assignment_proto_cc_impl", - "//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl", - "//xla/service/gpu:backend_configs_cc_impl", - "//xla/service/gpu/model:hlo_op_profile_proto_cc_impl", - "//xla/stream_executor:device_description_proto_cc_impl", - "//xla/stream_executor:stream_executor_impl", - "//xla/stream_executor/gpu:gpu_init_impl", - "@local_tsl//tsl/platform:env_impl", - "@local_tsl//tsl/platform:tensor_float_32_utils", - "@local_tsl//tsl/profiler/utils:time_utils_impl", - "@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl", - "@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl", - "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl", - "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl", - "//xla:autotune_results_proto_cc_impl", - "//xla:autotuning_proto_cc_impl", - "@local_tsl//tsl/protobuf:protos_all_cc_impl", - "@local_tsl//tsl/framework:allocator", - "@local_tsl//tsl/framework:allocator_registry_impl", - "@local_tsl//tsl/util:determinism", - ] - native.cc_binary(deps = deps, copts = copts, **kwargs) +def xla_cc_binary(deps = [], copts = tsl_copts(), **kwargs): + native.cc_binary(deps = deps + _XLA_SHARED_OBJECT_SENSITIVE_DEPS, copts = copts, **kwargs) -def xla_cc_test( - name, - deps = [], - **kwargs): +def xla_cc_test(name, deps = [], **kwargs): native.cc_test( name = name, - deps = deps + if_tsl_link_protobuf( - [], - [ - _tsl_clean_dep("@com_google_protobuf//:protobuf"), - # TODO(zacmustin): remove these in favor of more granular dependencies in each test. - clean_dep("//xla:xla_proto_cc_impl"), - clean_dep("//xla:xla_data_proto_cc_impl"), - clean_dep("//xla/service:hlo_proto_cc_impl"), - clean_dep("//xla/service:buffer_assignment_proto_cc_impl"), - clean_dep("//xla/service/memory_space_assignment:memory_space_assignment_proto_cc_impl"), - clean_dep("//xla/service/gpu:backend_configs_cc_impl"), - clean_dep("//xla/service/gpu/model:hlo_op_profile_proto_cc_impl"), - clean_dep("//xla/stream_executor:device_description_proto_cc_impl"), - clean_dep("//xla/stream_executor:device_id_utils"), - clean_dep("//xla/stream_executor:stream_executor_impl"), - clean_dep("//xla/stream_executor/gpu:gpu_cudamallocasync_allocator"), - clean_dep("//xla/stream_executor/gpu:gpu_init_impl"), - clean_dep("@local_tsl//tsl/profiler/utils:time_utils_impl"), - clean_dep("@local_tsl//tsl/profiler/backends/cpu:annotation_stack_impl"), - clean_dep("@local_tsl//tsl/profiler/backends/cpu:traceme_recorder_impl"), - clean_dep("@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc_impl"), - clean_dep("@local_tsl//tsl/profiler/protobuf:xplane_proto_cc_impl"), - clean_dep("//xla:autotune_results_proto_cc_impl"), - clean_dep("//xla:autotuning_proto_cc_impl"), - clean_dep("@local_tsl//tsl/protobuf:protos_all_cc_impl"), - clean_dep("@local_tsl//tsl/platform:env_impl"), - clean_dep("@local_tsl//tsl/framework:allocator"), - clean_dep("@local_tsl//tsl/framework:allocator_registry_impl"), - clean_dep("@local_tsl//tsl/util:determinism"), - ], - ) + - if_cuda_is_configured([ - clean_dep("//xla/stream_executor/cuda:cuda_stream"), - clean_dep("//xla/stream_executor/cuda:all_runtime"), - clean_dep("//xla/stream_executor/cuda:stream_executor_cuda"), - ]) + - if_rocm_is_configured([ - clean_dep("//xla/stream_executor/gpu:gpu_stream"), - 
clean_dep("//xla/stream_executor/rocm:all_runtime"), - clean_dep("//xla/stream_executor/rocm:stream_executor_rocm"), - ]), + deps = deps + _XLA_SHARED_OBJECT_SENSITIVE_DEPS, exec_properties = tf_exec_properties(kwargs), **kwargs ) From 57b76a1da45ba56d309e00f263fc7fdc214d7904 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Mon, 23 Oct 2023 14:00:36 -0700 Subject: [PATCH 027/246] Remove dependencies on other op libraries from //third_party/tensorflow/core/tpu/ops:sparse_core_ops. These dependencies are unnecessary, and they can introduce duplicate op wrappers when we generate Python bindings. PiperOrigin-RevId: 575921758 --- tensorflow/core/tpu/ops/BUILD | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tensorflow/core/tpu/ops/BUILD b/tensorflow/core/tpu/ops/BUILD index e49c75d54d269d..b529d004d2540f 100644 --- a/tensorflow/core/tpu/ops/BUILD +++ b/tensorflow/core/tpu/ops/BUILD @@ -201,12 +201,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/tpu:tpu_embedding_optimization_parameters_utils", - "//tensorflow/core/tpu/ops:tpu_compile_op", - "//tensorflow/core/tpu/ops:tpu_embedding_ops", - "//tensorflow/core/tpu/ops:tpu_execute_op", - "//tensorflow/core/tpu/ops:tpu_handle_to_key_op", - "//tensorflow/core/tpu/ops:tpu_partitioned_ops", - "//tensorflow/core/tpu/ops:tpu_round_robin_op", "@com_google_absl//absl/strings", ], alwayslink = 1, From 2328ddf3b3439e19b68cb3650c2c45573e4706c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 14:36:43 -0700 Subject: [PATCH 028/246] Add visibility for //learning/metadata/artifactoid/cc PiperOrigin-RevId: 575932165 --- tensorflow/cc/saved_model/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index b09ebe0bd0944c..935b37b37aa5c0 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -522,6 +522,7 @@ cc_library( visibility = [ "//learning/brain/contrib/hub/server/distro:__subpackages__", "//learning/brain/contrib/tpu_modeling:__subpackages__", + "//learning/metadata/artifactoid/cc:__subpackages__", "//learning/tfx/pipeline/util:__subpackages__", "//tensorflow/python/saved_model:__subpackages__", ], From be859e03626080af41b54b01bfdb98a766be4c37 Mon Sep 17 00:00:00 2001 From: "Ralf W. Grosse-Kunstleve" Date: Mon, 23 Oct 2023 15:11:38 -0700 Subject: [PATCH 029/246] Sprawling .pyi updates related to pybind11 PRs #4876. PiperOrigin-RevId: 575941719 --- tensorflow/python/client/_pywrap_tf_session.pyi | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/client/_pywrap_tf_session.pyi b/tensorflow/python/client/_pywrap_tf_session.pyi index a453092aeee274..b15067ea3ce38b 100644 --- a/tensorflow/python/client/_pywrap_tf_session.pyi +++ b/tensorflow/python/client/_pywrap_tf_session.pyi @@ -95,7 +95,7 @@ class OpsById: def __contains__(self, arg0: object) -> bool: ... def __delitem__(self, arg0: int) -> None: ... def __getitem__(self, arg0: int) -> object: ... - def __iter__(self) -> Iterator: ... + def __iter__(self) -> Iterator[int]: ... def __len__(self) -> int: ... def __setitem__(self, arg0: int, arg1: object) -> None: ... @@ -111,7 +111,7 @@ class OpsByName: def __contains__(self, arg0: object) -> bool: ... def __delitem__(self, arg0: str) -> None: ... def __getitem__(self, arg0: str) -> object: ... - def __iter__(self) -> Iterator: ... + def __iter__(self) -> Iterator[str]: ... def __len__(self) -> int: ... 
def __setitem__(self, arg0: str, arg1: object) -> None: ... From cb1f99598a2e9505ddabdffcf050fde1961ceabe Mon Sep 17 00:00:00 2001 From: Austin Anderson Date: Mon, 23 Oct 2023 17:42:18 -0700 Subject: [PATCH 030/246] Add nightly libtpu download to newer scripts Trying to get these in sync with the current jobs, which are failing, probably because of the wrong .so being used. PiperOrigin-RevId: 575978433 --- ci/official/envs/ci_default | 1 + ci/official/envs/nightly_linux_x86_tpu_py310 | 3 +-- ci/official/envs/nightly_linux_x86_tpu_py311 | 3 +-- ci/official/envs/nightly_linux_x86_tpu_py39 | 3 +-- ci/official/wheel.sh | 8 ++++++++ 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/ci/official/envs/ci_default b/ci/official/envs/ci_default index 183b2048ce1a5f..eb7938c8b3449d 100644 --- a/ci/official/envs/ci_default +++ b/ci/official/envs/ci_default @@ -16,6 +16,7 @@ TFCI_NIGHTLY_UPDATE_VERSION_ENABLE= TFCI_NVIDIA_SMI_ENABLE= TFCI_OUTPUT_DIR=build_output TFCI_LIBTPU_DOWNLOAD_ENABLE=0 +TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=0 TFCI_LIBTPU_DOWNLOAD_URL= TFCI_UPLOAD_LIB_ENABLE= TFCI_UPLOAD_LIB_LATEST_ENABLE= diff --git a/ci/official/envs/nightly_linux_x86_tpu_py310 b/ci/official/envs/nightly_linux_x86_tpu_py310 index da77f5ab3668ae..1f5b5eea178fa7 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py310 +++ b/ci/official/envs/nightly_linux_x86_tpu_py310 @@ -6,6 +6,5 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_ENABLE=1 -TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so +TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py311 b/ci/official/envs/nightly_linux_x86_tpu_py311 index 8f95c5df576b18..1a85c11048dadc 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py311 +++ b/ci/official/envs/nightly_linux_x86_tpu_py311 @@ -6,6 +6,5 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_ENABLE=1 -TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so +TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 diff --git a/ci/official/envs/nightly_linux_x86_tpu_py39 b/ci/official/envs/nightly_linux_x86_tpu_py39 index b75e450cda1a0a..4e88352241972f 100644 --- a/ci/official/envs/nightly_linux_x86_tpu_py39 +++ b/ci/official/envs/nightly_linux_x86_tpu_py39 @@ -6,6 +6,5 @@ TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_tpu TFCI_BUILD_PIP_PACKAGE_ARGS=(--tpu --nightly_flag) TFCI_DOCKER_IMAGE=tensorflow/build:latest-python${TFCI_PYTHON_VERSION} TFCI_DOCKER_REBUILD_ARGS=(--build-arg PYTHON_VERSION=$TFCI_PYTHON_VERSION --target=devel tools/tf_sig_build_dockerfiles) -TFCI_LIBTPU_DOWNLOAD_ENABLE=1 -TFCI_LIBTPU_DOWNLOAD_URL=https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.8.0/libtpu.so +TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE=1 TFCI_NIGHTLY_UPDATE_VERSION_ENABLE=1 diff --git a/ci/official/wheel.sh b/ci/official/wheel.sh index e3e569f4c112c2..f9faf1eeff13e8 100755 --- 
a/ci/official/wheel.sh +++ b/ci/official/wheel.sh @@ -29,6 +29,14 @@ fi if [[ "$TFCI_LIBTPU_DOWNLOAD_ENABLE" == 1 ]]; then wget -P ./tensorflow/lib/ "$TFCI_LIBTPU_DOWNLOAD_URL" fi +if [[ "$TFCI_LIBTPU_DOWNLOAD_NIGHTLY_ENABLE" == 1 ]]; then + # For nightly jobs, libtpu.so comes from the latest nightly libtpu build. + # Note: expects a working wheel for today + DATE=(TZ='America/Los_Angeles' date '+%Y%m%d') + tfrun wget "https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev${DATE}-py3-none-any.whl" -O libtpu.whl + # -j to discard intermediate directories; -o to overwrite if exists; -d to set output dir + tfrun unzip libtpu.whl libtpu/libtpu.so -j -o -d ./tensorflow/lib +fi tfrun bazel "${TFCI_BAZEL_BAZELRC_ARGS[@]}" build "${TFCI_BAZEL_COMMON_ARGS[@]}" //tensorflow/tools/pip_package:build_pip_package tfrun ./bazel-bin/tensorflow/tools/pip_package/build_pip_package "$TFCI_OUTPUT_DIR" "${TFCI_BUILD_PIP_PACKAGE_ARGS[@]}" From 860fdd1dad2e1f95b7f47d59cbc5849988a1d19e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 23 Oct 2023 19:09:44 -0700 Subject: [PATCH 031/246] Use random input for numerical tests in convert_tf_quant_to_mhlo_int_test This CL adds helper functions to generate random input and compare test results. I found that TF kernels and the lowering pass have different round schemes for quantize/rescaling etc: floor(x+0.5) vs round_nearest_even. So there maybe +/-1 errors. The lowering pass uses the latter, which is consistent with TF quantizer. (Requantize in TF kernel doesn't use the former, thus the good agreement). Therefore I removed most Q/DQ pairs in test cases so that they don't interfere with the evaluation of other ops. PiperOrigin-RevId: 575992943 --- .../mlir/quantization/stablehlo/BUILD | 3 + .../convert_tf_quant_to_mhlo_int_test.cc | 485 ++++++++++-------- 2 files changed, 263 insertions(+), 225 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index 32094e7f033fc6..86e942097b4695 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -238,9 +238,12 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/compiler/tf2xla:common", "//tensorflow/core:framework", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:nn", "//tensorflow/core/kernels/uniform_quant_ops:kernels", "//tensorflow/core/ops", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/random", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index 468e5518abebfc..434e2fa6516f3f 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include #include @@ -21,6 +22,7 @@ limitations under the License. 
#include #include "absl/log/check.h" +#include "absl/random/random.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -188,288 +190,321 @@ class ConvertTfQuantToMhloIntTest : public ::testing::Test { void ExecuteAndCompareResultsWithTfKernel( absl::string_view program, absl::Span arguments, - float error_tolerance = 0.1) { - TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(program)); + std::optional tf_program = std::nullopt, + double error_tolerance = 0.1) { + // Expected result is calculated by evaluating using TF kernels. In some + // cases, TF kernel behaves differently from lowered graph (e.g. Hybrid + // ops). So we optionally use a different graph to calculate the expected + // result. + TF_ASSERT_OK_AND_ASSIGN( + auto expected, + this->EvaluateTfFunction( + (tf_program.has_value() ? *tf_program : program), arguments)); + TF_ASSERT_OK_AND_ASSIGN(auto executable, this->CompileProgram(program)); TF_ASSERT_OK_AND_ASSIGN( - auto result_literal, + auto result, this->ExecuteProgramAndReturnSingleResult(executable.get(), arguments)); - TF_ASSERT_OK_AND_ASSIGN(auto expected, - this->EvaluateTfFunction(program, arguments)); - EXPECT_TRUE(xla::LiteralTestUtil::Near(*expected, *result_literal, + // Convert to double for comparison. This is needed for comparing integers + // since it LiteralTestUtil asserts different integers even if it is within + // error_spec. + TF_ASSERT_OK_AND_ASSIGN(auto expected_double, expected->Convert(xla::F64)) + TF_ASSERT_OK_AND_ASSIGN(auto result_double, result->Convert(xla::F64)) + EXPECT_TRUE(xla::LiteralTestUtil::Near(expected_double, result_double, xla::ErrorSpec(error_tolerance))); } + absl::StatusOr CreateRandomF32Literal( + absl::Span dims, float min = -100, float max = 100) { + TF_ASSIGN_OR_RETURN(auto shape, + xla::ShapeUtil::MakeValidatedShape(xla::F32, dims)); + return xla::LiteralUtil::CreateLiteralWithGenerator( + shape, [this, min, max](absl::Span dims) -> float { + return absl::Uniform(bitgen_, min, max); + }); + } + + absl::StatusOr CreateRandomI8Literal( + absl::Span dims, int8_t min = -128, int8_t max = 127) { + TF_ASSIGN_OR_RETURN(auto shape, + xla::ShapeUtil::MakeValidatedShape(xla::S8, dims)); + return xla::LiteralUtil::CreateLiteralWithGenerator( + shape, [this, min, max](absl::Span dims) -> int8_t { + return absl::Uniform(bitgen_, min, max); + }); + } + + absl::StatusOr CreateRandomI32Literal( + absl::Span dims, int32_t min = -128, int32_t max = 127) { + TF_ASSIGN_OR_RETURN(auto shape, + xla::ShapeUtil::MakeValidatedShape(xla::S32, dims)); + return xla::LiteralUtil::CreateLiteralWithGenerator( + shape, [this, min, max](absl::Span dims) -> int32_t { + return absl::Uniform(bitgen_, min, max); + }); + } + std::unique_ptr ctx_; std::unique_ptr pjrt_client_; xla::PjRtDevice* device_; + absl::BitGen bitgen_; }; TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAndDequantize) { constexpr absl::string_view kProgram = R"mlir( -func.func @main(%arg0: tensor<4xf32>) -> tensor<4xf32> { - %scale = "tf.Const"() { value = dense<10.0> : tensor } : () - -> tensor +func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> { + %scale = "tf.Const"() { value = dense<0.347> : tensor } : () -> tensor %zp = "tf.Const"() { value = dense<3> : tensor } : () -> tensor %0 = "tf.UniformQuantize"(%arg0, %scale, %zp) { quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 - } : (tensor<4xf32>, tensor, tensor) -> tensor<4x!tf_type.qint8> + } : (tensor<10xf32>, 
tensor, tensor) -> tensor<10x!tf_type.qint8> %1 = "tf.UniformDequantize"(%0, %scale, %zp) { quantization_axis = -1 : i64, quantization_min_val = -128 : i64, quantization_max_val = 127 : i64 - } : (tensor<4x!tf_type.qint8>, tensor, tensor) -> tensor<4xf32> - return %1 : tensor<4xf32> + } : (tensor<10x!tf_type.qint8>, tensor, tensor) -> tensor<10xf32> + return %1 : tensor<10xf32> })mlir"; - auto arg0 = - xla::LiteralUtil::CreateR1({100.0f, 20000.0f, -2409.0f, -25.1f}); - ExecuteAndCompareResultsWithTfKernel(kProgram, {&arg0}); + TF_ASSERT_OK_AND_ASSIGN(auto arg0, CreateRandomF32Literal({10})); + // error_tolerance is set to be slightly > scale because different rounding + // implementations for UniformQuantize in TF kernel and the lowering passes + // may cause +/-1 differences. + ExecuteAndCompareResultsWithTfKernel( + kProgram, {&arg0}, /*tf_program=*/std::nullopt, /*error_tolerance=*/0.35); } TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolution) { constexpr absl::string_view kProgram = R"mlir( -func.func @main(%input: tensor<1x2x2x1xf32>, %filter: tensor<2x1x1x1xf32>) -> tensor<1x2x2x1xf32> { - %input_scale = "tf.Const"() { value = dense<7.3> : tensor } : () - -> tensor - %input_zp = "tf.Const"() { value = dense<-45> : tensor } : () -> tensor - %filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () - -> tensor - %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %accum_scale = "tf.Const"() { value = dense<0.3431> : tensor } : () - -> tensor - %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %quant_input = "tf.UniformQuantize"(%input, %input_scale, %input_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", - attr_map = "", quantization_axis = -1 : i64, quantization_max_val = 127 : i64, - quantization_min_val = -128 : i64 - } : (tensor<1x2x2x1xf32>, tensor, tensor) -> tensor<1x2x2x1x!tf_type.qint8> - %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", - attr_map = "", quantization_axis = -1 : i64, - quantization_max_val = 127 : i64, quantization_min_val = -128 : i64 - } : (tensor<2x1x1x1xf32>, tensor, tensor) -> tensor<2x1x1x1x!tf_type.qint8> - %0 = "tf.UniformQuantizedConvolution"( - %quant_input, %quant_filter, %input_scale, %input_zp, - %filter_scale, %filter_zp, %accum_scale, %accum_zp - ) { - Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", - attr_map = "", batch_group_count = 1 : i64, - dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", - explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1], - lhs_quantization_axis = -1 : i64, lhs_quantization_max_val = 127 : i64, - lhs_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64, - output_quantization_max_val = 2147483647 : i64, - output_quantization_min_val = -2147483648 : i64, padding = "SAME", - rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64, - rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64, - window_strides = [1, 1] - } : (tensor<1x2x2x1x!tf_type.qint8>, tensor<2x1x1x1x!tf_type.qint8>, - tensor, tensor, tensor, tensor, tensor, tensor - ) -> tensor<1x2x2x1x!tf_type.qint32> - %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) { - quantization_axis = -1 : i64, quantization_min_val = -128 : i64, - quantization_max_val = 127 : i64 - } : (tensor<1x2x2x1x!tf_type.qint32>, tensor, tensor) -> tensor<1x2x2x1xf32> - return %output : tensor<1x2x2x1xf32> +func.func @main(%input: 
tensor<1x9x9x9xi8>, %filter: tensor<3x3x9x10xi8>) -> tensor<1x9x9x10xi32> { + %input_scale = "tf.Const"() { value = dense<2.0> : tensor } : () -> tensor + %input_zp = "tf.Const"() { value = dense<-10> : tensor } : () -> tensor + %filter_scale = "tf.Const"() { value = dense<0.5> : tensor } : () -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %accum_scale = "tf.Const"() { value = dense<1.0> : tensor } : () -> tensor + %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_input = "tf.Cast"(%input) {} : (tensor<1x9x9x9xi8>) -> + tensor<1x9x9x9x!tf_type.qint8> + %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x9x10xi8>) -> + tensor<3x3x9x10x!tf_type.qint8> + %0 = "tf.UniformQuantizedConvolution"( + %quant_input, %quant_filter, %input_scale, %input_zp, + %filter_scale, %filter_zp, %accum_scale, %accum_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", + attr_map = "", batch_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1], + lhs_quantization_axis = -1 : i64, lhs_quantization_max_val = 127 : i64, + lhs_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64, + output_quantization_max_val = 2147483647 : i64, + output_quantization_min_val = -2147483648 : i64, padding = "SAME", + rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64, + window_strides = [1, 1] + } : (tensor<1x9x9x9x!tf_type.qint8>, tensor<3x3x9x10x!tf_type.qint8>, + tensor, tensor, tensor, tensor, tensor, tensor + ) -> tensor<1x9x9x10x!tf_type.qint32> + %output = "tf.Cast"(%0) {} : (tensor<1x9x9x10x!tf_type.qint32>) -> tensor<1x9x9x10xi32> + return %output : tensor<1x9x9x10xi32> })mlir"; - auto input = xla::LiteralUtil::CreateR4( - {{{{14.f}, {-100.f}}, {{-200.f}, {350.f}}}}); - auto filter = xla::LiteralUtil::CreateR4({{{{4.1f}}}, {{{-2.f}}}}); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({1, 9, 9, 9})); + TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({3, 3, 9, 10})); ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeConvolutionHybrid) { + constexpr absl::string_view kTfProgram = R"mlir( +func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> { + %filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x10x20xi8>) -> + tensor<3x3x10x20x!tf_type.qint8> + %filter_new = "tf.UniformDequantize"(%quant_filter, %filter_scale, %filter_zp) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, + quantization_max_val = 127 : i64 + } : ( + tensor<3x3x10x20x!tf_type.qint8>, tensor, tensor + ) -> tensor<3x3x10x20xf32> + %0 = "tf.Conv2D"(%input, %filter_new) { + Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_FLOAT", + attr_map = "", batch_group_count = 1 : i64, + explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1], + padding = "SAME", rhs_dilation = [1, 1], strides = [1, 1, 1, 1] + } : (tensor<2x10x10x10xf32>, tensor<3x3x10x20xf32>) -> tensor<2x10x10x20xf32> + return %0 : tensor<2x10x10x20xf32> +})mlir"; constexpr absl::string_view kProgram = R"mlir( -func.func @main(%input: tensor<1x2x2x1xf32>, %filter: tensor<2x1x1x1xf32>) 
-> tensor<1x2x2x1xf32> { - %filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () - -> tensor - %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", - attr_map = "", quantization_axis = -1 : i64, - quantization_max_val = 127 : i64, quantization_min_val = -128 : i64 - } : (tensor<2x1x1x1xf32>, tensor, tensor) -> tensor<2x1x1x1x!tf_type.qint8> - %0 = "tf.UniformQuantizedConvolutionHybrid"( - %input, %quant_filter, %filter_scale, %filter_zp - ) { - Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", - attr_map = "", batch_group_count = 1 : i64, - dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", - explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1], - padding = "SAME", rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64, - rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64, - window_strides = [1, 1] - } : (tensor<1x2x2x1xf32>, tensor<2x1x1x1x!tf_type.qint8>, - tensor, tensor) -> tensor<1x2x2x1xf32> - return %0 : tensor<1x2x2x1xf32> +func.func @main(%input: tensor<2x10x10x10xf32>, %filter: tensor<3x3x10x20xi8>) -> tensor<2x10x10x20xf32> { + %filter_scale = "tf.Const"() { value = dense<0.047> : tensor } : () -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_filter = "tf.Cast"(%filter) {} : (tensor<3x3x10x20xi8>) -> tensor<3x3x10x20x!tf_type.qint8> + %0 = "tf.UniformQuantizedConvolutionHybrid"( + %input, %quant_filter, %filter_scale, %filter_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", + attr_map = "", batch_group_count = 1 : i64, + dimension_numbers = "\10\03\1A\02\01\02 \02(\032\02\00\01@\03J\02\01\02", + explicit_padding = [], feature_group_count = 1 : i64, lhs_dilation = [1, 1], + padding = "SAME", rhs_dilation = [1, 1], rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64, + window_strides = [1, 1] + } : (tensor<2x10x10x10xf32>, tensor<3x3x10x20x!tf_type.qint8>, + tensor, tensor) -> tensor<2x10x10x20xf32> + return %0 : tensor<2x10x10x20xf32> })mlir"; - auto input = xla::LiteralUtil::CreateR4( - {{{{14.f}, {-100.f}}, {{-200.f}, {350.f}}}}); - auto filter = xla::LiteralUtil::CreateR4({{{{4.1f}}}, {{{-2.f}}}}); - // The large tolerance here is expected because - // tf.UniformQuantizedConvolutionHybrid does DRQ. But StableHLO hybrid ops - // does weight-only. - ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, - /*error_tolerance=*/5.0); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomF32Literal({2, 10, 10, 10})); + TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({3, 3, 10, 20})); + // TF kernels for UniformQuantizedConvolutionHybrid does DRQ. But StableHLO + // hybrid ops does weight-only. So we use a different TF graph for evaluating + // expected weight-only quantized results. 
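  // Editorial sketch of the distinction (illustrative): dynamic-range
  // quantization (DRQ) quantizes the float input at runtime and runs an
  // integer convolution, while weight-only keeps the input in float and
  // only dequantizes the weights:
  //   DRQ:         q(input) conv q(filter) -> rescale back to float
  //   weight-only: input conv dq(filter)   -> float directly
  // That is why the float-graph kTfProgram above supplies the expectation.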
+ ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDot) { constexpr absl::string_view kProgram = R"mlir( -func.func @main(%input: tensor<1x2xf32>, %filter: tensor<2x3xf32>) -> tensor<1x3xf32> { - %input_scale = "tf.Const"() { value = dense<0.588> : tensor } : () - -> tensor - %input_zp = "tf.Const"() { value = dense<42> : tensor } : () -> tensor - %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () - -> tensor - %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %accum_scale = "tf.Const"() { value = dense<0.013818> : tensor } : () - -> tensor - %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %quant_input = "tf.UniformQuantize"(%input, %input_scale, %input_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 127 : i64, - quantization_min_val = -128 : i64 - } : (tensor<1x2xf32>, tensor, tensor) -> tensor<1x2x!tf_type.qint8> - %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 127 : i64, - quantization_min_val = -128 : i64 - } : (tensor<2x3xf32>, tensor, tensor) -> tensor<2x3x!tf_type.qint8> - %0 = "tf.UniformQuantizedDot"( - %quant_input, %quant_filter, %input_scale, %input_zp, %filter_scale, - %filter_zp, %accum_scale, %accum_zp - ) { - Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", attr_map = "", - device = "", lhs_quantization_axis = -1 : i64, - lhs_quantization_max_val = 127 : i64, lhs_quantization_min_val = -128 : i64, - output_quantization_axis = -1 : i64, output_quantization_max_val = 2147483647 : i64, - output_quantization_min_val = -2147483648 : i64, rhs_quantization_axis = -1 : i64, - rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64 - } : ( - tensor<1x2x!tf_type.qint8>, tensor<2x3x!tf_type.qint8>, tensor, - tensor, tensor, tensor, tensor, tensor - ) -> tensor<1x3x!tf_type.qint32> - %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) { - quantization_axis = -1 : i64, quantization_min_val = -128 : i64, - quantization_max_val = 127 : i64 - } : (tensor<1x3x!tf_type.qint32>, tensor, tensor) -> tensor<1x3xf32> - return %output : tensor<1x3xf32> +func.func @main(%input: tensor<8x9xi8>, %filter: tensor<9x10xi8>) -> tensor<8x10xi32> { + %input_scale = "tf.Const"() { value = dense<0.588> : tensor } : () -> tensor + %input_zp = "tf.Const"() { value = dense<42> : tensor } : () -> tensor + %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %accum_scale = "tf.Const"() { value = dense<0.013818> : tensor } : () -> tensor + %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_input = "tf.Cast"(%input) {} : (tensor<8x9xi8>) -> tensor<8x9x!tf_type.qint8> + %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8> + %0 = "tf.UniformQuantizedDot"( + %quant_input, %quant_filter, %input_scale, %input_zp, %filter_scale, + %filter_zp, %accum_scale, %accum_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT32", attr_map = "", + device = "", lhs_quantization_axis = -1 : i64, + lhs_quantization_max_val = 127 : i64, + lhs_quantization_min_val = -128 : i64, + output_quantization_axis = -1 : i64, + output_quantization_max_val = 
2147483647 : i64, + output_quantization_min_val = -2147483648 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 127 : i64, + rhs_quantization_min_val = -128 : i64 + } : ( + tensor<8x9x!tf_type.qint8>, tensor<9x10x!tf_type.qint8>, tensor, + tensor, tensor, tensor, tensor, tensor + ) -> tensor<8x10x!tf_type.qint32> + %output = "tf.Cast"(%0) {} : (tensor<8x10x!tf_type.qint32>) -> tensor<8x10xi32> + return %output : tensor<8x10xi32> })mlir"; - auto input = xla::LiteralUtil::CreateR2({{50.f, -100.f}}); - auto filter = - xla::LiteralUtil::CreateR2({{1.f, 2.f, 3.f}, {-1.f, -3.f, 1.f}}); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({8, 9})); + TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({9, 10})); ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); } TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeDotHybrid) { + constexpr absl::string_view kTfProgram = R"mlir( +func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> { + %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8> + %filter_new = "tf.UniformDequantize"(%quant_filter, %filter_scale, %filter_zp) { + quantization_axis = -1 : i64, quantization_min_val = -128 : i64, + quantization_max_val = 127 : i64 + } : (tensor<9x10x!tf_type.qint8>, tensor, tensor) -> tensor<9x10xf32> + %0 = "tf.MatMul"(%input, %filter_new) { + } : (tensor<8x9xf32>, tensor<9x10xf32>) -> tensor<8x10xf32> + return %0 : tensor<8x10xf32> +})mlir"; constexpr absl::string_view kProgram = R"mlir( -func.func @main(%input: tensor<1x2xf32>, %filter: tensor<2x3xf32>) -> tensor<1x3xf32> { - %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () - -> tensor - %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %quant_filter = "tf.UniformQuantize"(%filter, %filter_scale, %filter_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 127 : i64, - quantization_min_val = -128 : i64 - } : (tensor<2x3xf32>, tensor, tensor) -> tensor<2x3x!tf_type.qint8> - %0 = "tf.UniformQuantizedDotHybrid"( - %input, %quant_filter, %filter_scale, %filter_zp - ) { - Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", attr_map = "", - device = "", rhs_quantization_axis = -1 : i64, - rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64 - } : (tensor<1x2xf32>, tensor<2x3x!tf_type.qint8>, tensor, tensor) -> tensor<1x3xf32> - return %0 : tensor<1x3xf32> +func.func @main(%input: tensor<8x9xf32>, %filter: tensor<9x10xi8>) -> tensor<8x10xf32> { + %filter_scale = "tf.Const"() { value = dense<0.0235> : tensor } : () + -> tensor + %filter_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_filter = "tf.Cast"(%filter) {} : (tensor<9x10xi8>) -> tensor<9x10x!tf_type.qint8> + %0 = "tf.UniformQuantizedDotHybrid"( + %input, %quant_filter, %filter_scale, %filter_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_FLOAT", attr_map = "", + device = "", rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 127 : i64, rhs_quantization_min_val = -128 : i64 + } : (tensor<8x9xf32>, tensor<9x10x!tf_type.qint8>, tensor, tensor) -> tensor<8x10xf32> + return %0 : tensor<8x10xf32> })mlir"; - auto input = xla::LiteralUtil::CreateR2({{50.f, -100.f}}); - auto filter = - 
xla::LiteralUtil::CreateR2({{1.f, 2.f, 3.f}, {-1.f, -3.f, 1.f}}); - ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomF32Literal({8, 9})); + TF_ASSERT_OK_AND_ASSIGN(auto filter, CreateRandomI8Literal({9, 10})); + // TF kernels for UniformQuantizedDotHybrid do DRQ. But StableHLO hybrid ops + // do weight-only. So we use a different TF graph for evaluating expected + // weight-only quantized results. + ExecuteAndCompareResultsWithTfKernel(kProgram, {&input, &filter}, kTfProgram); } TEST_F(ConvertTfQuantToMhloIntTest, UniformRequantize) { constexpr absl::string_view kProgram = R"mlir( -func.func @main(%input: tensor<4xf32>) -> tensor<4xf32> { - %input_scale = "tf.Const"() { value = dense<0.2235> : tensor } : () - -> tensor - %input_zp = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor - %output_scale = "tf.Const"() { value = dense<0.11> : tensor } : () - -> tensor - %output_zp = "tf.Const"() { value = dense<3> : tensor } : () -> tensor - %0 = "tf.UniformQuantize"(%input, %input_scale, %input_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT8", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 127 : i64, - quantization_min_val = -128 : i64 - } : (tensor<4xf32>, tensor, tensor) -> tensor<4x!tf_type.qint8> - %1 = "tf.UniformRequantize"( - %0, %input_scale, %input_zp, %output_scale, %output_zp - ) { - Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", - device = "", input_quantization_axis = -1, input_quantization_max_val = 127 : i64, - input_quantization_min_val = -128 : i64, output_quantization_axis = -1 : i64, - output_quantization_max_val = 127 : i64, output_quantization_min_val = -128 : i64 - } : ( - tensor<4x!tf_type.qint8>, tensor, tensor, tensor, tensor - ) -> tensor<4x!tf_type.qint8> - %2 = "tf.UniformDequantize"(%1, %output_scale, %output_zp) { - quantization_axis = -1 : i64, - quantization_min_val = -128 : i64, - quantization_max_val = 127 : i64 - } : (tensor<4x!tf_type.qint8>, tensor, tensor) -> tensor<4xf32> - return %2 : tensor<4xf32> +func.func @main(%input: tensor<10xi8>) -> tensor<10xi8> { + %input_scale = "tf.Const"() { value = dense<0.2235> : tensor } : () -> tensor + %input_zp = "tf.Const"() { value = dense<-2> : tensor } : () -> tensor + %output_scale = "tf.Const"() { value = dense<0.11> : tensor } : () -> tensor + %output_zp = "tf.Const"() { value = dense<3> : tensor } : () -> tensor + %0 = "tf.Cast"(%input) {} : (tensor<10xi8>) -> tensor<10x!tf_type.qint8> + %1 = "tf.UniformRequantize"( + %0, %input_scale, %input_zp, %output_scale, %output_zp + ) { + Tin = "tfdtype$DT_QINT8", Tout = "tfdtype$DT_QINT8", attr_map = "", + device = "", input_quantization_axis = -1, + input_quantization_max_val = 127 : i64, + input_quantization_min_val = -128 : i64, + output_quantization_axis = -1 : i64, + output_quantization_max_val = 127 : i64, + output_quantization_min_val = -128 : i64 + } : ( + tensor<10x!tf_type.qint8>, tensor, tensor, tensor, + tensor + ) -> tensor<10x!tf_type.qint8> + %2 = "tf.Cast"(%1) {} : (tensor<10x!tf_type.qint8>) -> tensor<10xi8> + return %2 : tensor<10xi8> })mlir"; - auto input = xla::LiteralUtil::CreateR1({3.f, -10.f, -2.1f, 8.7f}); + TF_ASSERT_OK_AND_ASSIGN(auto input, CreateRandomI8Literal({10})); ExecuteAndCompareResultsWithTfKernel(kProgram, {&input}); } TEST_F(ConvertTfQuantToMhloIntTest, UniformQuantizeAdd) { constexpr absl::string_view kProgram = R"mlir( -func.func @main(%lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) ->
tensor<2x2xf32> { - %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor } : () - -> tensor - %lhs_zp = "tf.Const"() { value = dense<42> : tensor } : () -> tensor - %rhs_scale = "tf.Const"() { value = dense<0.0239> : tensor } : () - -> tensor - %rhs_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %accum_scale = "tf.Const"() { value = dense<0.013> : tensor } : () - -> tensor - %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor - %quant_lhs = "tf.UniformQuantize"(%lhs, %lhs_scale, %lhs_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT32", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, - quantization_min_val = -2147483648 : i64 - } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint32> - %quant_rhs = "tf.UniformQuantize"(%rhs, %rhs_scale, %rhs_zp) { - Tin = "tfdtype$DT_FLOAT", Tout = "tfdtype$DT_QINT32", attr_map = "", - quantization_axis = -1 : i64, quantization_max_val = 2147483647 : i64, - quantization_min_val = -2147483648 : i64 - } : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2x!tf_type.qint32> - %0 = "tf.UniformQuantizedAdd"( - %quant_lhs, %quant_rhs, %lhs_scale, %lhs_zp, %rhs_scale, - %rhs_zp, %accum_scale, %accum_zp - ) { - Tin = "tfdtype$DT_QINT32", Tout = "tfdtype$DT_QINT32", attr_map = "", - device = "", lhs_quantization_axis = -1 : i64, - lhs_quantization_max_val = 2147483647 : i64, lhs_quantization_min_val = -2147483648 : i64, - output_quantization_axis = -1 : i64, output_quantization_max_val = 2147483647 : i64, - output_quantization_min_val = -2147483648 : i64, rhs_quantization_axis = -1 : i64, - rhs_quantization_max_val = 2147483647 : i64, rhs_quantization_min_val = -2147483648 : i64 - } : ( - tensor<2x2x!tf_type.qint32>, tensor<2x2x!tf_type.qint32>, tensor, - tensor, tensor, tensor, tensor, tensor - ) -> tensor<2x2x!tf_type.qint32> - %output = "tf.UniformDequantize"(%0, %accum_scale, %accum_zp) { - quantization_axis = -1 : i64, quantization_min_val = -128 : i64, - quantization_max_val = 127 : i64 - } : (tensor<2x2x!tf_type.qint32>, tensor, tensor) -> tensor<2x2xf32> - return %output : tensor<2x2xf32> +func.func @main(%lhs: tensor<10x10xi32>, %rhs: tensor<10x10xi32>) -> tensor<10x10xi32> { + %lhs_scale = "tf.Const"() { value = dense<0.518> : tensor } : () -> tensor + %lhs_zp = "tf.Const"() { value = dense<42> : tensor } : () -> tensor + %rhs_scale = "tf.Const"() { value = dense<0.0239> : tensor } : () -> tensor + %rhs_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %accum_scale = "tf.Const"() { value = dense<0.013> : tensor } : () -> tensor + %accum_zp = "tf.Const"() { value = dense<0> : tensor } : () -> tensor + %quant_lhs = "tf.Cast"(%lhs) {} : (tensor<10x10xi32>) -> tensor<10x10x!tf_type.qint32> + %quant_rhs = "tf.Cast"(%rhs) {} : (tensor<10x10xi32>) -> tensor<10x10x!tf_type.qint32> + %0 = "tf.UniformQuantizedAdd"( + %quant_lhs, %quant_rhs, %lhs_scale, %lhs_zp, %rhs_scale, + %rhs_zp, %accum_scale, %accum_zp + ) { + Tin = "tfdtype$DT_QINT32", Tout = "tfdtype$DT_QINT32", attr_map = "", + device = "", lhs_quantization_axis = -1 : i64, + lhs_quantization_max_val = 2147483647 : i64, + lhs_quantization_min_val = -2147483648 : i64, + output_quantization_axis = -1 : i64, + output_quantization_max_val = 2147483647 : i64, + output_quantization_min_val = -2147483648 : i64, + rhs_quantization_axis = -1 : i64, + rhs_quantization_max_val = 2147483647 : i64, + rhs_quantization_min_val = -2147483648 : i64 + } : ( + tensor<10x10x!tf_type.qint32>, 
tensor<10x10x!tf_type.qint32>, tensor, + tensor, tensor, tensor, tensor, tensor + ) -> tensor<10x10x!tf_type.qint32> + %1 = "tf.Cast"(%0) {} : (tensor<10x10x!tf_type.qint32>) -> tensor<10x10xi32> + return %1 : tensor<10x10xi32> })mlir"; - auto lhs = xla::LiteralUtil::CreateR2({{23.f, 12.f}, {-10.f, -3.f}}); - auto rhs = xla::LiteralUtil::CreateR2({{1.f, 2.f}, {-1.f, -3.f}}); - ExecuteAndCompareResultsWithTfKernel(kProgram, {&lhs, &rhs}); + TF_ASSERT_OK_AND_ASSIGN(auto lhs, CreateRandomI32Literal({10, 10})); + TF_ASSERT_OK_AND_ASSIGN(auto rhs, CreateRandomI32Literal({10, 10})); + // error_tolerance is set to be 1 because different rounding implementations + // in TF kernel and the lowering passes may cause +/-1 differences. + ExecuteAndCompareResultsWithTfKernel(kProgram, {&lhs, &rhs}, + /*tf_program=*/std::nullopt, + /*error_tolerance=*/1.0); } } // namespace From ddab1a76974f56b68af2bcdfce96c68e7a784efc Mon Sep 17 00:00:00 2001 From: Thai Nguyen Date: Mon, 23 Oct 2023 20:24:06 -0700 Subject: [PATCH 032/246] Add ResourceGather to the list of quantizable ops PiperOrigin-RevId: 576004858 --- .../mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index d415e755e9c6f6..22f95fa5369215 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -45,8 +45,8 @@ bool IsOpWithQuantizableTrait(Operation* op) { // Supported quantizable ops. return isa(op); + TF::ResourceGatherOp, TF::DepthwiseConv2dNativeOp, TF::Conv3DOp, + TF::BatchMatMulV2Op, TF::EinsumOp>(op); } bool IsOpWithInt8TypeOperand(Operation* op) { From 3a0645b863a154db0711f61bdedd2d20bb423a05 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Mon, 23 Oct 2023 23:11:49 -0700 Subject: [PATCH 033/246] Internal Code Change PiperOrigin-RevId: 576031178 --- tensorflow/lite/delegates/flex/test/BUILD | 3 --- 1 file changed, 3 deletions(-) diff --git a/tensorflow/lite/delegates/flex/test/BUILD b/tensorflow/lite/delegates/flex/test/BUILD index e0bb3f4dba63c9..238b6adadd824c 100644 --- a/tensorflow/lite/delegates/flex/test/BUILD +++ b/tensorflow/lite/delegates/flex/test/BUILD @@ -8,9 +8,6 @@ load("@bazel_skylib//rules:build_test.bzl", "build_test") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], - default_visibility = [ - "//tensorflow/lite/android:__subpackages__", - ], licenses = ["notice"], ) From 889a3ad60665ad8e233674fd87e0fbcebc7d38e7 Mon Sep 17 00:00:00 2001 From: Alina Sbirlea Date: Mon, 23 Oct 2023 23:34:05 -0700 Subject: [PATCH 034/246] Integrate LLVM at llvm/llvm-project@e214bdac51c4 Updates LLVM usage to match [e214bdac51c4](https://github.com/llvm/llvm-project/commit/e214bdac51c4) PiperOrigin-RevId: 576035284 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index 8be65a702ab288..de644ff01f52d2 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "e558be51bab051d1471d92e967f8a2aecc13567a" - LLVM_SHA256 = "94069d8ccbab6451c7b070f26d926c310cdba17481ebc997273d95e0c82d86f8" + LLVM_COMMIT = "e214bdac51c46b554e1fb99de6b6b6c735b75bf1" + LLVM_SHA256 = "0da83d809440c4d70061656ad5868ddc60c522fc5d64e08c36118416642ffc43" tf_http_archive( name = name, From d82247fe04641229cb9a3c18d1a1ce96453ab282 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2023 01:31:28 -0700 Subject: [PATCH 035/246] Update TFRT dependency to use revision http://github.com/tensorflow/runtime/commit/2d980b4b1e9a945eddd53524470663fa119b0204. PiperOrigin-RevId: 576057550 --- third_party/tf_runtime/workspace.bzl | 4 ++-- third_party/xla/third_party/tf_runtime/workspace.bzl | 4 ++-- .../xla/third_party/tsl/third_party/tf_runtime/workspace.bzl | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/third_party/tf_runtime/workspace.bzl b/third_party/tf_runtime/workspace.bzl index 0b3107e4c868ad..c642fd045acccd 100644 --- a/third_party/tf_runtime/workspace.bzl +++ b/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" - TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" + TFRT_COMMIT = "2d980b4b1e9a945eddd53524470663fa119b0204" + TFRT_SHA256 = "a7a76fb20a292d943feed7a73e6222951e9f81b040681f4b5af703ac7b3198a2" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tf_runtime/workspace.bzl index 0b3107e4c868ad..c642fd045acccd 100644 --- a/third_party/xla/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. 
- TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" - TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" + TFRT_COMMIT = "2d980b4b1e9a945eddd53524470663fa119b0204" + TFRT_SHA256 = "a7a76fb20a292d943feed7a73e6222951e9f81b040681f4b5af703ac7b3198a2" tf_http_archive( name = "tf_runtime", diff --git a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl index 0b3107e4c868ad..c642fd045acccd 100644 --- a/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl +++ b/third_party/xla/third_party/tsl/third_party/tf_runtime/workspace.bzl @@ -6,8 +6,8 @@ def repo(): """Imports TFRT.""" # Attention: tools parse and update these lines. - TFRT_COMMIT = "18a68a86884cc3b033586e725cb65557420fe3ea" - TFRT_SHA256 = "451004d621e595c31339345165d371d8f7109b4f249bd8f2690fd10a3491d5c0" + TFRT_COMMIT = "2d980b4b1e9a945eddd53524470663fa119b0204" + TFRT_SHA256 = "a7a76fb20a292d943feed7a73e6222951e9f81b040681f4b5af703ac7b3198a2" tf_http_archive( name = "tf_runtime", From 8630f297a375868487875e121b3e627d34923df3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2023 02:02:04 -0700 Subject: [PATCH 036/246] compat: Update forward compatibility horizon to 2023-10-24 PiperOrigin-RevId: 576064029 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 323ff2fc0f436b..30b369380d5baa 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -29,7 +29,7 @@ # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 10, 23) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2023, 10, 24) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From 2cf8b1c62a98c859bbe2ae69160680eea6aae160 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2023 02:02:07 -0700 Subject: [PATCH 037/246] Update GraphDef version to 1659. PiperOrigin-RevId: 576064045 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 2da8be25e6461e..ecac6c25de259d 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 1658 // Updated: 2023/10/23 +#define TF_GRAPH_DEF_VERSION 1659 // Updated: 2023/10/24 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). 
// From 85ac1c6ddc93d4f53ff5b2c5c1c7bac7a8a44030 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 24 Oct 2023 02:18:54 -0700 Subject: [PATCH 038/246] Allow reduction users in multi-output fusions with buffer aliasing (FusionCanShareBufferHint) PiperOrigin-RevId: 576067660 --- third_party/xla/xla/service/gpu/BUILD | 3 ++ .../xla/xla/service/gpu/buffer_sharing.cc | 44 +++++++++++++++---- .../service/gpu/gpu_copy_insertion_test.cc | 40 +++++++++++++++-- 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 8f97780283b215..6804f72bae10e7 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -3361,9 +3361,12 @@ cc_library( deps = [ ":backend_configs_cc", ":cublas_cudnn", + ":hlo_fusion_analysis", ":ir_emission_utils", "//xla:shape_util", "//xla/hlo/ir:hlo", + "//xla/stream_executor:device_description", + "//xla/stream_executor:device_description_proto_cc", "@com_google_absl//absl/container:flat_hash_set", ], ) diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc index d02cbc7a27e179..64421596dcbd60 100644 --- a/third_party/xla/xla/service/gpu/buffer_sharing.cc +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -21,13 +21,18 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "xla/hlo/ir/hlo_casting_utils.h" #include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_instructions.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/gpu/backend_configs.pb.h" #include "xla/service/gpu/cublas_cudnn.h" +#include "xla/service/gpu/hlo_fusion_analysis.h" #include "xla/service/gpu/ir_emission_utils.h" #include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/stream_executor/device_description.h" +#include "xla/stream_executor/device_description.pb.h" namespace xla { namespace gpu { @@ -35,7 +40,8 @@ namespace gpu { std::optional FusionCanShareBufferHint(const HloInstruction* user, const HloInstruction* operand, const ShapeIndex& user_index) { - if (user->opcode() != HloOpcode::kFusion) { + const HloFusionInstruction* fusion = DynCast(user); + if (fusion == nullptr) { return std::nullopt; } @@ -65,10 +71,21 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, } } + // Allow multiple output users, if they end in reductions. + // This only works for the reduction emitter, as it calculates the reduction + // first, i.e. before processing other outputs (that may overwrite the input). + stream_executor::GpuDeviceInfoProto device_info; + stream_executor::DeviceDescription device_description(device_info); + auto analysis = HloFusionAnalysis::Create(fusion, &device_description); + bool is_reduction_emitter = analysis->GetEmitterFusionKind() == + HloFusionAnalysis::EmitterFusionKind::kReduction; + const HloInstruction* reduction_hero = + is_reduction_emitter ? analysis->FindHeroReduction() : nullptr; + // We need to make sure that the fusion parameter is accessed in the same - // iteration order as the fusion output. Also, there should not be two fusion - // outputs that consume the fusion parameter, because we do not want to share - // the same fusion operand with two different fusion outputs. To make sure + // iteration order as the fusion output. Also, there should not be any other + // fusion output that accesses it in a different iteration order.
To make sure // that the iteration order is the same, we only allow ops on the path from // fusion parameter to fusion output which are elementwise (no copy) or // bitcast or an elementwise dynamic update slice (i.e. with the first operand @@ -88,16 +105,21 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, q.push(fusion_param); visited.insert(fusion_param); bool found_path_to_output = false; + int reached_root = 0; while (!q.empty()) { HloInstruction* hlo_operand = q.front(); q.pop(); if (hlo_operand == output) { found_path_to_output = true; - // The output should have at most 1 user: the tuple op (in case of a - // multi-output fusion) - if (hlo_operand->user_count() > 1) { + // We still need to process the users of 'hlo_operand'. There can be other + // reduction users in addition to the tuple user. + if (hlo_operand->user_count() > 1 && !is_reduction_emitter) { return false; } + } + // Reduction emitter processes the reduction first, so the values below it + // will not interfere with buffer sharing. + if (hlo_operand == reduction_hero) { continue; } for (HloInstruction* hlo : hlo_operand->users()) { @@ -134,7 +156,8 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, } else if ((!hlo->IsElementwiseOnOperand( hlo->operand_index(hlo_operand)) || hlo->opcode() == HloOpcode::kCopy) && - hlo->opcode() != HloOpcode::kBitcast) { + hlo->opcode() != HloOpcode::kBitcast && + hlo->opcode() != HloOpcode::kTuple && hlo != reduction_hero) { // This check also catches the case that we reach a different fusion // output, as that fusion output would have a tuple op as user, which we // do not allow here. @@ -151,9 +174,12 @@ std::optional FusionCanShareBufferHint(const HloInstruction* user, return false; } } + if (hlo->IsRoot()) { + ++reached_root; + } } } - return found_path_to_output; + return found_path_to_output && (user_index.empty() || reached_root == 1); } std::optional CanShareBufferHint(const HloInstruction* user, diff --git a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc index dc059e40c53580..e3e6425b022bb6 100644 --- a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc +++ b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc @@ -201,13 +201,14 @@ fused_computation { param_1.1 = f32[2,3]{1,0} parameter(1) neg = f32[2,3]{1,0} negate(param_1.1) mul = f32[2,3]{1,0} multiply(param_0.1, neg) - ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}) tuple(mul, neg) + transpose = f32[3,2]{1,0} transpose(neg), dimensions={1,0} + ROOT tuple = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) tuple(mul, neg, transpose) } ENTRY main { param_0 = f32[2,3]{1,0} parameter(0) param_1 = f32[2,3]{1,0} parameter(1) - ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation + ROOT fusion = (f32[2,3]{1,0}, f32[2,3]{1,0}, f32[3,2]{1,0}) fusion(param_0, param_1), kind=kLoop, calls=fused_computation } )"; @@ -216,7 +217,7 @@ ENTRY main { HloInstruction* fusion = module->entry_computation()->root_instruction(); ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); // The second operand cannot share the buffer with the second fusion output, - // because the 'neg' op is also used on the path to the first fusion output. + // because the 'neg' op is also used by a non-elementwise op. 
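// Editor's note: a stripped-down sketch of the reachability walk that
// FusionCanShareBufferHint performs, to make the tests around it easier to
// follow. Node and Kind are invented stand-ins for HloInstruction, and the
// real pass also handles dynamic-update-slice, tuples, and root counting;
// this is an illustration, not the XLA implementation.
#include <queue>
#include <set>
#include <vector>

enum class Kind { kElementwise, kBitcast, kOther };
struct Node {
  Kind kind = Kind::kElementwise;
  std::vector<Node*> users;
};

// Breadth-first from the fusion parameter: sharing is safe only if the output
// is reachable and every visited user preserves iteration order. The
// reduction hero is exempt because the reduction emitter materializes it
// before any other output could overwrite the shared buffer.
bool CanShareSketch(Node* param, Node* output, Node* reduction_hero) {
  std::queue<Node*> q;
  std::set<Node*> visited = {param};
  q.push(param);
  bool found_output = false;
  while (!q.empty()) {
    Node* n = q.front();
    q.pop();
    if (n == output) found_output = true;
    if (n == reduction_hero) continue;  // its users cannot clobber the input
    for (Node* u : n->users) {
      if (!visited.insert(u).second) continue;
      if (u->kind == Kind::kOther) return false;  // may reorder accesses
      q.push(u);
    }
  }
  return found_output;
}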
ExpectOptionalFalse( FusionCanShareBufferHint(fusion, fusion->operand(1), {1})); // The first operand cannot share the buffer with the second fusion output, @@ -225,6 +226,39 @@ ENTRY main { FusionCanShareBufferHint(fusion, fusion->operand(0), {1})); } +TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedReductionEmitter) { + constexpr char kModuleString[] = R"( +HloModule TestModule + +%maximum { + %lhs = f32[] parameter(0) + %rhs = f32[] parameter(1) + ROOT %res = f32[] maximum(%lhs, %rhs) +} + +%fused_computation { + %lhs = f32[3,40] parameter(0) + %rhs = f32[3,40] parameter(1) + %add = f32[3,40] add(%lhs, %rhs) + %bc = f32[120] bitcast(%add) + %init = f32[] constant(-inf) + %max = f32[] reduce(%bc, %init), dimensions={0}, to_apply=%maximum + ROOT %result = (f32[], f32[3,40]) tuple(%max, %add) +} + +ENTRY %main { + %lhs = f32[3,40] parameter(0) + %rhs = f32[3,40] parameter(1) + ROOT %fusion = (f32[], f32[3,40]) fusion(%lhs, %rhs), + kind=kLoop, calls=%fused_computation +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(kModuleString)); + HloInstruction* fusion = module->entry_computation()->root_instruction(); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {1})); +} + TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedScatterFusion) { const char* const kModuleString = R"( HloModule fusion From ba6f57b24248a26e862743c08573db089e517ae9 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 24 Oct 2023 04:43:55 -0700 Subject: [PATCH 039/246] Mark mnist_train_test as requiring network access. PiperOrigin-RevId: 576099007 --- tensorflow/compiler/mlir/tfr/examples/mnist/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD index ee79b37a73d1cc..135bc20b970bc5 100644 --- a/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD +++ b/tensorflow/compiler/mlir/tfr/examples/mnist/BUILD @@ -81,6 +81,7 @@ distribute_py_strict_test( "notap", # The test is too long to run as part of llvm presubmits (b/173661843). "notpu", # Takes too long (b/192305423) "notsan", # Not needed, and there were issues with timeouts. 
+ "requires-net:external", ], # TODO(b/175056184): Re-enable xla_enable_strict_auto_jit once the issues From 9d3212e2464f803f0f241a8b74c11ff94c36714c Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 24 Oct 2023 07:40:39 -0700 Subject: [PATCH 040/246] Integrate LLVM at llvm/llvm-project@eb67b34740b3 Updates LLVM usage to match [eb67b34740b3](https://github.com/llvm/llvm-project/commit/eb67b34740b3) PiperOrigin-RevId: 576135571 --- third_party/llvm/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/third_party/llvm/workspace.bzl b/third_party/llvm/workspace.bzl index de644ff01f52d2..5e8b072c63dd6c 100644 --- a/third_party/llvm/workspace.bzl +++ b/third_party/llvm/workspace.bzl @@ -4,8 +4,8 @@ load("//third_party:repo.bzl", "tf_http_archive") def repo(name): """Imports LLVM.""" - LLVM_COMMIT = "e214bdac51c46b554e1fb99de6b6b6c735b75bf1" - LLVM_SHA256 = "0da83d809440c4d70061656ad5868ddc60c522fc5d64e08c36118416642ffc43" + LLVM_COMMIT = "eb67b34740b37909f1b213fca6c564257577be25" + LLVM_SHA256 = "8f8fd78c23e9b61b9ba3db2bdfef40ecffe83ed08222e52b0a0ea8ee2717ac38" tf_http_archive( name = name, From e863f7fb6e68c0e8c0f8a3c3b7571b0ad42d0fc9 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 24 Oct 2023 08:42:45 -0700 Subject: [PATCH 041/246] Support MHLO serialization of CustomCall with typed_ffi PiperOrigin-RevId: 576152547 --- .../hlo_legalize_to_stablehlo.cc | 17 +++---- .../stablehlo_legalize_to_hlo.cc | 46 ++++++++++++------- ...lo-legalize-to-stablehlo-experimental.mlir | 17 ------- .../mhlo/hlo-legalize-to-stablehlo.mlir | 17 +++++++ .../mhlo/stablehlo-legalize-to-hlo.mlir | 5 +- 5 files changed, 59 insertions(+), 43 deletions(-) diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc index 166b13284302d1..4774aac93d4b56 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/hlo_legalize_to_stablehlo/hlo_legalize_to_stablehlo.cc @@ -107,13 +107,6 @@ bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) { // Proposal: https://github.com/openxla/stablehlo/issues/742. if (hasPackedNibble(hloOp.getPrecisionConfig())) return true; } - if constexpr (std::is_same::value) { - // StableHLO CustomCall doesn't support API_VERSION_TYPED_FFI yet. - // Proposal: https://github.com/openxla/stablehlo/issues/637. - if (hloOp.getApiVersion() == - mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) - return true; - } if constexpr (std::is_same::value) { // StableHLO DotGeneral doesn't support PACKED_NIBBLE yet. // Proposal: https://github.com/openxla/stablehlo/issues/742. @@ -132,13 +125,21 @@ bool hasExperimentalFeaturesNotInStablehlo(HloOpTy hloOp) { // frontends but are not yet part of StableHLO. Such features might be a good // fit for StableHLO, and are usually accompanied by a StableHLO GitHub ticket. template -std::optional getPublicFeaturesNotInStablehlo(HloOpTy) { +std::optional getPublicFeaturesNotInStablehlo(HloOpTy hloOp) { // StableHLO doesn't support TanOp yet. // Proposal: https://github.com/openxla/stablehlo/issues/954 if constexpr (std::is_same::value) { // Version 1: Initial version for TanOp. return 1; } + // StableHLO CustomCall doesn't support API_VERSION_TYPED_FFI yet. + // Proposal: https://github.com/openxla/stablehlo/issues/637. 
+ if constexpr (std::is_same::value) { + // Version 1: Initial version for TYPED_FFI + if (hloOp.getApiVersion() == + mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI) + return 1; + } return std::nullopt; } diff --git a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc index 3675c705053251..a3c401a7b1abe1 100644 --- a/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc +++ b/third_party/xla/xla/mlir_hlo/mhlo/transforms/stablehlo_legalize_to_hlo/stablehlo_legalize_to_hlo.cc @@ -263,6 +263,34 @@ LogicalResult rewriteCustomCallAsMhloOp(stablehlo::CustomCallOp stablehloOp, return success(); } +// Preserve backward compatibility of typed_ffi custom calls by converting: +// `stablehlo.custom_call @foo(%arg0) { mhlo.backend_config = {...} }` +// ==> +// `mhlo.custom_call @foo(%arg0) { backend_config = {...}, api_version = 4}` +// +// Fails if StableHLO op has non-empty backend_config, or uses API version +// other than API_VERSION_ORIGINAL. +LogicalResult fixupMhloBackendConfig(stablehlo::CustomCallOp stablehloOp, + mhlo::CustomCallOp hloOp) { + auto stablehloBackendConfig = stablehloOp->getAttr("mhlo.backend_config"); + if (stablehloBackendConfig) { + if (auto oldHloBackendConfig = + hloOp.getBackendConfigAttr() + .template dyn_cast_or_null()) { + if (!oldHloBackendConfig.empty()) return failure(); + } else { + return failure(); + } + if (stablehloOp.getApiVersion() != + stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL) + return failure(); + + hloOp.setBackendConfigAttr(stablehloBackendConfig); + hloOp.setApiVersion(mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI); + } + return success(); +} + template class StablehloToHloOpConverter : public OpConversionPattern { public: @@ -323,23 +351,9 @@ class StablehloToHloOpConverter : public OpConversionPattern { stablehloOp, hloTypes, hloOperands, hloAttrs); } + // For backward compatibility, fix custom call with mhlo.backend_config if constexpr (std::is_same::value) { - auto stablehloBackendConfig = stablehloOp->getAttr("mhlo.backend_config"); - if (stablehloBackendConfig) { - if (auto oldHloBackendConfig = - hloOp.getBackendConfigAttr() - .template dyn_cast_or_null()) { - if (oldHloBackendConfig != "") return failure(); - } else { - return failure(); - } - if (stablehloOp.getApiVersion() != - stablehlo::CustomCallApiVersion::API_VERSION_ORIGINAL) - return failure(); - - hloOp.setBackendConfigAttr(stablehloBackendConfig); - hloOp.setApiVersion(mhlo::CustomCallApiVersion::API_VERSION_TYPED_FFI); - } + if (failed(fixupMhloBackendConfig(stablehloOp, hloOp))) return failure(); } // Finally, populate the regions while converting argument types diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir index fb1a9e041f49f0..ed2d07b367abc0 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo-experimental.mlir @@ -42,23 +42,6 @@ func.func @op_all_to_all_tuple(%arg0: tensor<128x4xf32>, %arg1: tensor<128x4xf32 // ----- -// CHECK-LABEL: "op_custom_call_api_version_typed_ffi" -func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { - // CHECK: "stablehlo.custom_call"(%arg0) 
{ - // CHECK-SAME: call_target_name = "mhlo.custom_call" - // CHECK-SAME: mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"} - // CHECK-SAME: } : (tensor) -> tensor - // expected-error@+1 {{failed to legalize operation 'mhlo.custom_call' that was explicitly marked illegal}} - %0 = "mhlo.custom_call"(%arg0) { - call_target_name = "foo", - backend_config = {foo = "bar"}, - api_version = 4 : i32 - } : (tensor) -> tensor - return %0 : tensor -} - -// ----- - // CHECK-LABEL: "attr_precision_packed_nibble" func.func @attr_precision_packed_nibble(%arg0: tensor<8x16xf32>, %arg1: tensor<16x8xf32>) -> tensor<8x8xf32> { // CHECK: "stablehlo.custom_call"(%arg0, %arg1) { diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir index 0ad2e0165b76ad..6de9c1a64b37b7 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/hlo-legalize-to-stablehlo.mlir @@ -152,6 +152,23 @@ func.func @attr_custom_call_api_version_status_returning_unified(%arg0: tensor } +// ----- + +// CHECK-LABEL: "attr_custom_call_api_version_typed_ffi" +func.func @attr_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor { + // CHECK: "stablehlo.custom_call"(%arg0) { + // CHECK-SAME: call_target_name = "mhlo.custom_call" + // CHECK-SAME: mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"}, + // CHECK-SAME: mhlo.version = 1 : i64 + // CHECK-SAME: } : (tensor) -> tensor + %0 = "mhlo.custom_call"(%arg0) { + call_target_name = "foo", + backend_config = {foo = "bar"}, + api_version = 4 : i32 + } : (tensor) -> tensor + return %0 : tensor +} + // CustomCallSchedule aka #mhlo is unsupported at the moment (see negative test below). // DequantizeMode aka #mhlo is unused at the moment. // DomainKind aka #mhlo is unsupported at the moment (see negative test below). 
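Editor's note: the .mlir tests above and below show the serialized form this commit settles on: an op carrying a feature StableHLO cannot yet express round-trips as a stablehlo.custom_call targeting @mhlo.custom_call, with the original attributes stashed under mhlo.attributes and a version stamp under mhlo.version. A compressed sketch of the per-op version dispatch used by getPublicFeaturesNotInStablehlo follows; OpTan and OpCustomCall are invented placeholders, not the real MHLO op classes.

#include <cstdint>
#include <optional>
#include <type_traits>

struct OpTan {};                               // stands in for mhlo::TanOp
struct OpCustomCall { int api_version = 0; };  // stands in for mhlo::CustomCallOp

// Returns the serialization version for features not yet in StableHLO, or
// nullopt when the op can be expressed as plain StableHLO.
template <typename OpTy>
std::optional<int64_t> PublicFeatureVersion(const OpTy& op) {
  if constexpr (std::is_same<OpTy, OpTan>::value) {
    return 1;  // version 1: initial serialization of tan
  }
  if constexpr (std::is_same<OpTy, OpCustomCall>::value) {
    if (op.api_version == 4) return 1;  // version 1: API_VERSION_TYPED_FFI
  }
  return std::nullopt;
}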
diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir index d1dc28fed0e868..af8419890635ab 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir +++ b/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/stablehlo-legalize-to-hlo.mlir @@ -715,7 +715,8 @@ func.func @op_custom_call_api_version_typed_ffi(%arg0: tensor) -> tensor) -> tensor %0 = "stablehlo.custom_call"(%arg0) { call_target_name = "mhlo.custom_call", - mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"} + mhlo.attributes = {api_version = 4 : i32, backend_config = {foo = "bar"}, call_target_name = "foo"}, + mhlo.version = 1 : i64 } : (tensor) -> tensor return %0 : tensor } @@ -729,7 +730,7 @@ func.func @op_custom_call_mhlo_backend_config(%arg0: tensor<16x256xbf16>) -> ten // CHECK-SAME: } : (tensor<16x256xbf16>) -> tensor<16x4xbf16> %4 = stablehlo.custom_call @foo(%arg0) { "mhlo.backend_config" = {aggregate_to_topk = true} - } : (tensor<16x256xbf16>) -> tensor<16x4xbf16> + } : (tensor<16x256xbf16>) -> tensor<16x4xbf16> return %4 : tensor<16x4xbf16> } From 3171da22964e4735f4184a9742d12d14715014ce Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Tue, 24 Oct 2023 09:15:16 -0700 Subject: [PATCH 042/246] Roll forward with fix cd0d7b5 "[NVIDIA XLA GPU] Fix allreduce/reducescatter combiner bug #6126" Update creations of collective ops to maintain the correct pointers to their unique `collective_call_instruction_`. Clone computation if needed. PiperOrigin-RevId: 576161662 --- .../xla/xla/service/all_reduce_promotion.cc | 1 + .../xla/xla/service/float_normalization.cc | 11 ++ .../xla/service/float_normalization_test.cc | 163 ++++++++++++++++++ .../xla/service/reduce_scatter_decomposer.cc | 6 +- .../xla/xla/service/spmd/spmd_partitioner.cc | 29 +++- 5 files changed, 200 insertions(+), 10 deletions(-) diff --git a/third_party/xla/xla/service/all_reduce_promotion.cc b/third_party/xla/xla/service/all_reduce_promotion.cc index 00469a2a6c9e78..30965128a81524 100644 --- a/third_party/xla/xla/service/all_reduce_promotion.cc +++ b/third_party/xla/xla/service/all_reduce_promotion.cc @@ -49,6 +49,7 @@ std::unique_ptr CloneAllReduce( return inst->GetModule()->AddEmbeddedComputation(promoted.Build()); }(); new_inst->set_to_apply(to_apply_promoted); + to_apply_promoted->SetCollectiveCallInstruction(new_inst.get()); return new_inst; } diff --git a/third_party/xla/xla/service/float_normalization.cc b/third_party/xla/xla/service/float_normalization.cc index 267a92383aab17..84774a3b4884ed 100644 --- a/third_party/xla/xla/service/float_normalization.cc +++ b/third_party/xla/xla/service/float_normalization.cc @@ -335,6 +335,9 @@ Status FloatNormalizationVisitor::HandleMultipleOutputs(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { + if (comp->IsCollectiveCalledComputation()) { + continue; + } bool comp_has_low_precision = false; if (comp->root_instruction()->shape().element_type() == HighPrecisionType()) { @@ -411,6 +414,9 @@ Status FloatNormalizationVisitor::HandleInstruction(HloInstruction* hlo) { std::vector low_precision_called_comps; for (auto* comp : hlo->called_computations()) { + if (comp->IsCollectiveCalledComputation()) { + continue; + } bool comp_has_low_precision = false; high_prec_count += CountSubshapesWithMatchingType( comp->root_instruction()->shape(), HighPrecisionType()); @@ 
-549,6 +555,11 @@ StatusOr FloatNormalization::Run( ", before:\n" + module->ToString()); FloatNormalizationVisitor visitor(float_support_, this); for (auto* comp : module->MakeComputationPostOrder(execution_threads)) { + if (comp->IsCollectiveCalledComputation()) { + XLA_VLOG_LINES(2, "Skip processing collective called computation: " + + comp->ToString()); + continue; + } TF_RETURN_IF_ERROR(comp->Accept(&visitor)); } XLA_VLOG_LINES(2, "FloatNormalization::Run() for " + diff --git a/third_party/xla/xla/service/float_normalization_test.cc b/third_party/xla/xla/service/float_normalization_test.cc index 3a41960bad932e..2d6a976ff59df4 100644 --- a/third_party/xla/xla/service/float_normalization_test.cc +++ b/third_party/xla/xla/service/float_normalization_test.cc @@ -76,6 +76,38 @@ class TestFloatSupport : public FloatSupport { } }; +// The test float class that doesn't support any compute ops for low-precision +// but supports some collectives. +class TestFloatNoComputeSupport : public FloatSupport { + public: + explicit TestFloatNoComputeSupport(PrimitiveType low_precision_type) + : FloatSupport(low_precision_type) {} + ~TestFloatNoComputeSupport() override = default; + + bool SupportsLowPrecisionOperand(const HloInstruction& hlo, + int64_t operand_index) const override { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll || + hlo.opcode() == HloOpcode::kAllReduce || + hlo.opcode() == HloOpcode::kReduceScatter) { + return true; + } + return false; + } + + bool SupportsLowPrecisionOutput(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement || + hlo.opcode() == HloOpcode::kAllToAll || + hlo.opcode() == HloOpcode::kAllReduce || + hlo.opcode() == HloOpcode::kReduceScatter) { + return true; + } + return false; + } +}; + class FloatNormalizationTest : public HloTestBase { protected: FloatNormalizationTest() @@ -485,4 +517,135 @@ TEST_F(FloatNormalizationTest, ResolveIfUnsupportedF8e5m2) { EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); } +class FloatNormalizationNoComputeSupportTest : public FloatNormalizationTest { + protected: + bool Normalize(HloModule* module, PrimitiveType low_precision_type = BF16) { + TestFloatNoComputeSupport float_support(low_precision_type); + FloatNormalization normalization(&float_support); + + StatusOr result = normalization.Run(module); + EXPECT_IS_OK(result.status()); + + HloVerifier verifier(/*layout_sensitive=*/false, + /*allow_mixed_precision=*/true); + EXPECT_IS_OK(verifier.Run(module).status()); + + return result.value(); + } +}; + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyMultiOuputAllReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_shape_b = ShapeUtil::MakeShape(BF16, {16, 16}); + + HloInstruction* a = 
builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape_b, "b")); + + HloInstruction* crs = builder.AddInstruction(HloInstruction::CreateAllReduce( + ShapeUtil::MakeTupleShape({bf16_shape_a, bf16_shape_b}), {a, b}, + reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape_b, crs, 1)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. + EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), BF16); +} + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyAllReduce) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + HloInstruction* crs = builder.AddInstruction( + HloInstruction::CreateAllReduce(bf16_shape_a, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. 
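// Editor's note: the behavior pinned down by these tests comes from one
// guard in FloatNormalization::Run, shown in compressed, hypothetical form
// below. Computation and its field are invented stand-ins (the real check is
// IsCollectiveCalledComputation()): a collective's to_apply region is skipped
// so it keeps the low-precision element type of the collective that calls it.
#include <vector>

struct Computation {
  bool is_collective_to_apply = false;
};

void NormalizeToHighPrecision(Computation*) { /* promote BF16 to F32 */ }

void RunNormalizationSketch(const std::vector<Computation*>& comps) {
  for (Computation* comp : comps) {
    if (comp->is_collective_to_apply) continue;  // left untouched
    NormalizeToHighPrecision(comp);
  }
}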
+ EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); +} + +TEST_F(FloatNormalizationNoComputeSupportTest, + NoNormalizationForToApplyReduceScatter) { + auto module = CreateNewVerifiedModule(); + HloComputation::Builder sum_builder("sum"); + auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(BF16, {}), "x")); + auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(BF16, {}), "y")); + sum_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(BF16, {}), HloOpcode::kAdd, x, y)); + HloComputation* reduction = + module->AddEmbeddedComputation(sum_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + Shape bf16_shape_a = ShapeUtil::MakeShape(BF16, {2, 4}); + Shape bf16_shape_scattered = ShapeUtil::MakeShape(BF16, {1, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_shape_a, "a")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateReduceScatter( + bf16_shape_scattered, {a}, reduction, + /*replica_groups=*/{}, + /*constrain_layout=*/false, + /*channel_id=*/std::nullopt, + /*use_global_device_ids=*/false, /*scatter_dimension*/ 0)); + + auto computation = module->AddEntryComputation(builder.Build()); + // Since we skip processing to_apply region, nothing should change in the + // original HLO. + EXPECT_FALSE(Normalize(module.get())); + EXPECT_EQ(computation->root_instruction()->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(0)->shape().element_type(), BF16); + EXPECT_EQ(crs->to_apply()->root_instruction()->opcode(), HloOpcode::kAdd); +} + } // namespace xla diff --git a/third_party/xla/xla/service/reduce_scatter_decomposer.cc b/third_party/xla/xla/service/reduce_scatter_decomposer.cc index 1fb197bcea8b82..59366639b84e2f 100644 --- a/third_party/xla/xla/service/reduce_scatter_decomposer.cc +++ b/third_party/xla/xla/service/reduce_scatter_decomposer.cc @@ -55,11 +55,15 @@ StatusOr ReduceScatterDecomposer::Run( } // Create an all-reduce + HloComputation *apply_clone = module->AddComputationAndUnifyNamesAndIds( + rs->to_apply()->Clone(), /*is_entry=*/false); HloInstruction *ar = computation->AddInstruction(HloInstruction::CreateAllReduce( - rs->operand(0)->shape(), rs->operands(), rs->to_apply(), + rs->operand(0)->shape(), rs->operands(), apply_clone, rs->replica_groups(), rs->constrain_layout(), channel_id, rs->use_global_device_ids())); + apply_clone->SetCollectiveCallInstruction(ar); + // Create start indices for a dynamic slice to decompose the all-reduce // results. TF_ASSIGN_OR_RETURN( diff --git a/third_party/xla/xla/service/spmd/spmd_partitioner.cc b/third_party/xla/xla/service/spmd/spmd_partitioner.cc index 429181dbecfe70..98438aa08ae251 100644 --- a/third_party/xla/xla/service/spmd/spmd_partitioner.cc +++ b/third_party/xla/xla/service/spmd/spmd_partitioner.cc @@ -48,7 +48,6 @@ limitations under the License. 
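// Editor's note: the reduce_scatter_decomposer hunk above and the
// spmd_partitioner hunk below follow the same pattern, sketched here with
// invented stand-in types (Module, Computation, Instruction are not the HLO
// classes): a to_apply computation may back-reference at most one collective,
// so each newly built collective receives a fresh clone of the reduction and
// the clone records its caller via SetCollectiveCallInstruction.
#include <memory>
#include <utility>
#include <vector>

struct Instruction;

struct Computation {
  Instruction* collective_call = nullptr;
  std::unique_ptr<Computation> Clone() const {
    return std::make_unique<Computation>();
  }
  void SetCollectiveCallInstruction(Instruction* op) { collective_call = op; }
};

struct Instruction {
  Computation* to_apply = nullptr;
};

struct Module {
  std::vector<std::unique_ptr<Computation>> comps;
  std::vector<std::unique_ptr<Instruction>> insts;
  Computation* AddComputation(std::unique_ptr<Computation> c) {
    comps.push_back(std::move(c));
    return comps.back().get();
  }
};

// One clone per collective keeps the back-pointer unambiguous.
Instruction* MakeAllReduceSketch(Module& m, Computation* reduction) {
  Computation* clone = m.AddComputation(reduction->Clone());
  m.insts.push_back(std::make_unique<Instruction>(Instruction{clone}));
  Instruction* all_reduce = m.insts.back().get();
  clone->SetCollectiveCallInstruction(all_reduce);
  return all_reduce;
}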
#include "xla/service/hlo_cse.h" #include "xla/service/hlo_dce.h" #include "xla/service/hlo_pass_pipeline.h" -#include "xla/service/pattern_matcher.h" #include "xla/service/shape_inference.h" #include "xla/service/spmd/custom_call_handler.h" #include "xla/service/spmd/spmd_partitioner_util.h" @@ -4730,10 +4729,16 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, for (int64_t i = 0; i < num_replicas; ++i) { groups[i].add_replica_ids(i); } - return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, groups, - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/false)); + HloComputation* reduction_clone = + reduction->parent()->AddComputationAndUnifyNamesAndIds( + reduction->Clone(), false); + HloInstruction* all_reduce = + b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction_clone, groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/false)); + reduction_clone->SetCollectiveCallInstruction(all_reduce); + return all_reduce; } std::vector device_groups; @@ -4746,10 +4751,16 @@ SPMDCollectiveOpsCreator GetDefaultCollectiveOpsCreator(int64_t num_partitions, } } } - return b->AddInstruction(HloInstruction::CreateAllReduce( - operand->shape(), {operand}, reduction, device_groups, - /*constrain_layout=*/false, channel_id, - /*use_global_device_ids=*/true)); + HloComputation* reduction_clone = + reduction->parent()->AddComputationAndUnifyNamesAndIds( + reduction->Clone(), false); + HloInstruction* all_reduce = + b->AddInstruction(HloInstruction::CreateAllReduce( + operand->shape(), {operand}, reduction_clone, device_groups, + /*constrain_layout=*/false, channel_id, + /*use_global_device_ids=*/true)); + reduction_clone->SetCollectiveCallInstruction(all_reduce); + return all_reduce; }, [num_partitions](SpmdBuilder* b, HloInstruction* operand, std::vector>& src_dst_pairs, From 4ad796cbea378f435c23926b492e940e29f59720 Mon Sep 17 00:00:00 2001 From: Juan Martinez Castellanos Date: Tue, 24 Oct 2023 09:41:38 -0700 Subject: [PATCH 043/246] Redirect more references away from the `lib/io:lib` target and onto the new single-source-file targets. 
PiperOrigin-RevId: 576169652 --- tensorflow/python/BUILD | 8 ++++++-- tensorflow/python/framework/BUILD | 4 ++-- tensorflow/python/platform/BUILD | 1 - 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 210787452b9752..e3a27909cac23f 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -189,7 +189,9 @@ py_strict_library( "//tensorflow/python/grappler:tf_cluster", "//tensorflow/python/grappler:tf_item", "//tensorflow/python/grappler:tf_optimizer", - "//tensorflow/python/lib/io:lib", + "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/lib/io:python_io", + "//tensorflow/python/lib/io:tf_record", "//tensorflow/python/module", "//tensorflow/python/ops:array_ops", "//tensorflow/python/ops:array_ops_stack", @@ -386,7 +388,9 @@ py_strict_library( "//tensorflow/python/framework:ops", "//tensorflow/python/framework:test_combinations_lib", "//tensorflow/python/framework:versions", - "//tensorflow/python/lib/io:lib", + "//tensorflow/python/lib/io:file_io", + "//tensorflow/python/lib/io:python_io", + "//tensorflow/python/lib/io:tf_record", "//tensorflow/python/module", "//tensorflow/python/ops:audio_ops_gen", "//tensorflow/python/ops:bincount_ops", diff --git a/tensorflow/python/framework/BUILD b/tensorflow/python/framework/BUILD index 3e32c7f86c66e6..761b9868859aff 100644 --- a/tensorflow/python/framework/BUILD +++ b/tensorflow/python/framework/BUILD @@ -1929,7 +1929,7 @@ pytype_strict_library( "//tensorflow/core:protos_all_py", "//tensorflow/python/client:pywrap_tf_session", "//tensorflow/python/eager:context", - "//tensorflow/python/lib/io:lib", + "//tensorflow/python/lib/io:file_io", "//tensorflow/python/platform:tf_logging", "//tensorflow/python/util:compat", "@pypi_packaging//:pkg", @@ -1963,7 +1963,7 @@ pytype_strict_library( deps = [ ":byte_swap_tensor", ":ops", - "//tensorflow/python/lib/io:lib", + "//tensorflow/python/lib/io:file_io", "//tensorflow/python/util:tf_export", ], ) diff --git a/tensorflow/python/platform/BUILD b/tensorflow/python/platform/BUILD index b3a7f6400d87f0..9a29df190d4645 100644 --- a/tensorflow/python/platform/BUILD +++ b/tensorflow/python/platform/BUILD @@ -91,7 +91,6 @@ py_strict_library( "@absl_py//absl:app", "@absl_py//absl/testing:absltest", "//tensorflow/python/framework:errors", - "//tensorflow/python/lib/io:lib", "//tensorflow/python/util:tf_decorator", "//tensorflow/python/util:tf_inspect", ]), From cd72dc78c6248369b7e06572a715e633d3bbc008 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2023 10:12:01 -0700 Subject: [PATCH 044/246] Fix statistics string from Gb/s to Gbyte/s. 
PiperOrigin-RevId: 576179616 --- tensorflow/lite/delegates/gpu/common/task/profiling_info.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc b/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc index 630290f6bc3260..0c9d3ad866ba44 100644 --- a/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc +++ b/tensorflow/lite/delegates/gpu/common/task/profiling_info.cc @@ -49,7 +49,7 @@ std::string ProfilingInfo::GetDetailedReport() const { dispatch.read_mem_size + dispatch.write_mem_size; const double giga_bytes = total_size / 1024.0 / 1024.0 / 1024.0; const double giga_bytes_per_sec = times_per_sec * giga_bytes; - result += ", " + std::to_string(giga_bytes_per_sec) + " Gb/s"; + result += ", " + std::to_string(giga_bytes_per_sec) + " Gbyte/s"; } if (dispatch.flops) { const double giga_flops = dispatch.flops / 1000.0 / 1000.0 / 1000.0; From 453c7a6d3fac08316a353aee89120f523c797153 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 24 Oct 2023 10:29:38 -0700 Subject: [PATCH 045/246] Do not allow more than 2^31 rows in a sparse core stacked embedding. PiperOrigin-RevId: 576186428 --- .../core/tpu/kernels/sparse_core_layout.cc | 6 +++- .../core/tpu/kernels/sparse_core_layout.h | 8 ++++- .../tpu/kernels/sparse_core_layout_test.cc | 29 +++++++++++++++++++ 3 files changed, 41 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.cc b/tensorflow/core/tpu/kernels/sparse_core_layout.cc index ff1bdcac063e9c..8243cc04622c4f 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.cc @@ -113,6 +113,11 @@ absl::Status SparseCoreLayoutStacker::AddTable(tsl::StringPiece table_name, continue; } + if (row_limit_ != 0 && + ts.unsharded_height + padded_height >= row_limit_) { + continue; + } + // We found a stack we can put it in. stack = &ts; break; @@ -158,7 +163,6 @@ absl::Status SparseCoreLayoutStacker::AddTable(tsl::StringPiece table_name, stack->total_variable_shard_bytes += variable_shard_bytes; stack->total_activation_mem_bytes += activation_mem_bytes; - ++num_tables_; return absl::OkStatus(); } diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout.h b/tensorflow/core/tpu/kernels/sparse_core_layout.h index 5ee44a113f1ac1..5c9c15dad532b8 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout.h +++ b/tensorflow/core/tpu/kernels/sparse_core_layout.h @@ -54,6 +54,10 @@ class SparseCoreLayoutStacker { CHECK(stacks_by_group_.empty()) << "must call before AddTable"; stacking_enabled_ = stacking_enabled; } + void SetStackingRowLimit(int64_t row_limit) { + CHECK(stacks_by_group_.empty()) << "must call before AddTable"; + row_limit_ = row_limit; + } // Add a new table. Arguments: // table_name: How this table will be referred to. @@ -102,7 +106,9 @@ class SparseCoreLayoutStacker { bool stacking_enabled_ = true; int64_t activation_mem_bytes_limit_ = 0; int64_t variable_shard_bytes_limit_ = 0; - int num_tables_ = 0; + // Sparse core ops use signed int for row numbers so we had better not stack + // beyond this limit. + int64_t row_limit_ = (1LL << 31) - 1; // All the stacks that we currently know about. 
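// Editor's note: worked numbers for the default limit above,
// row_limit_ = (1LL << 31) - 1, matching the RespectsRowLimit test in the
// next hunk: each table adds 1 << 29 padded rows, so three tables stack and
// a fourth, which would reach 2^31, is placed alone.
#include <cassert>
#include <cstdint>

int main() {
  const int64_t kLimit = (1LL << 31) - 1;  // 2147483647, max signed int32
  const int64_t kRows = 1LL << 29;         // 536870912 rows per table
  assert(2 * kRows + kRows < kLimit);      // 3rd table: 1610612736, stacks
  assert(3 * kRows + kRows >= kLimit);     // 4th table: 2147483648, rejected
  return 0;
}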
Note that we use a btree_map // rather than a flat_hash_map so the resulting order is deterministic as long diff --git a/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc b/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc index aeba71e787b86d..883561cb4abf45 100644 --- a/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc +++ b/tensorflow/core/tpu/kernels/sparse_core_layout_test.cc @@ -131,6 +131,35 @@ TEST(SparseCoreLayoutStacker, RespectsVariableShardLimit) { )pb")))); } +TEST(SparseCoreLayoutStacker, RespectsRowLimit) { + SparseCoreLayoutStacker stacker(2); + // Disable the other limits. + stacker.SetActivationMemoryBytesLimit(0); + stacker.SetVariableShardBytesLimit(0); + + // Here there are several identical tables that contribute 2^30 rows. Since + // the default row limit is 2^31-1, they should not be able to stack. + ASSERT_OK(stacker.AddTable("table1", 1 << 29, 8, "stack1", 1024)); + ASSERT_OK(stacker.AddTable("table2", 1 << 29, 8, "stack1", 1024)); + ASSERT_OK(stacker.AddTable("table3", 1 << 29, 8, "stack1", 1024)); + ASSERT_OK(stacker.AddTable("table4", 1 << 29, 8, "stack1", 1024)); + EXPECT_THAT(stacker.GetLayouts(), IsOkAndHolds(Partially(EqualsProto(R"pb( + tables { + table_name: 'table1' + stacked_table_name: 'table1_table2_table3' + } + tables { + table_name: 'table2' + stacked_table_name: 'table1_table2_table3' + } + tables { + table_name: 'table3' + stacked_table_name: 'table1_table2_table3' + } + tables { table_name: 'table4' stacked_table_name: 'table4' } + )pb")))); +} + } // namespace } // namespace tpu } // namespace tensorflow From 84e77213965f4eedb36aa711e41908a704854e38 Mon Sep 17 00:00:00 2001 From: Derek Murray Date: Tue, 24 Oct 2023 10:35:18 -0700 Subject: [PATCH 046/246] Remove operators that were accidentally included in the top-level API. 
PiperOrigin-RevId: 576188602
---
 tensorflow/core/api_def/BUILD                 |  5 +-
 .../base_api/api_def_ConvertToCooTensor.pbtxt |  4 ++
 ...etMinibatchSplitsWithPhysicalReplica.pbtxt |  4 ++
 ...tMinibatchesInCsrWithPhysicalReplica.pbtxt |  4 ++
 ...pi_def_StoreMinibatchStatisticsInFdo.pbtxt |  4 ++
 ...f_TPUAnnotateTensorsWithDynamicShape.pbtxt |  4 ++
 .../api_def_TPUCopyWithDynamicShape.pbtxt     |  8 +++
 .../api_def_XlaSparseCoreAdagrad.pbtxt        |  4 ++
 ...api_def_XlaSparseCoreAdagradMomentum.pbtxt |  4 ++
 .../base_api/api_def_XlaSparseCoreAdam.pbtxt  |  4 ++
 .../base_api/api_def_XlaSparseCoreFtrl.pbtxt  |  4 ++
 .../base_api/api_def_XlaSparseCoreSgd.pbtxt   |  4 ++
 .../api_def_XlaSparseDenseMatmul.pbtxt        |  4 ++
 ...enseMatmulGradWithAdagradAndCsrInput.pbtxt |  4 ++
 ...ulGradWithAdagradMomentumAndCsrInput.pbtxt |  4 ++
 ...seDenseMatmulGradWithAdamAndCsrInput.pbtxt |  4 ++
 ...seDenseMatmulGradWithFtrlAndCsrInput.pbtxt |  4 ++
 ...rseDenseMatmulGradWithSgdAndCsrInput.pbtxt |  4 ++
 ...def_XlaSparseDenseMatmulWithCsrInput.pbtxt |  4 ++
 .../api_def_ConvertToCooTensor.pbtxt          |  4 ++
 ...etMinibatchSplitsWithPhysicalReplica.pbtxt |  4 ++
 ...tMinibatchesInCsrWithPhysicalReplica.pbtxt |  4 ++
 ...pi_def_StoreMinibatchStatisticsInFdo.pbtxt |  4 ++
 ...f_TPUAnnotateTensorsWithDynamicShape.pbtxt |  4 ++
 .../api_def_TPUCopyWithDynamicShape.pbtxt     |  4 ++
 .../api_def_XlaSparseCoreAdagrad.pbtxt        |  4 ++
 ...api_def_XlaSparseCoreAdagradMomentum.pbtxt |  4 ++
 .../api_def_XlaSparseCoreAdam.pbtxt           |  4 ++
 .../api_def_XlaSparseCoreFtrl.pbtxt           |  4 ++
 .../python_api/api_def_XlaSparseCoreSgd.pbtxt |  4 ++
 .../api_def_XlaSparseDenseMatmul.pbtxt        |  4 ++
 ...enseMatmulGradWithAdagradAndCsrInput.pbtxt |  4 ++
 ...ulGradWithAdagradMomentumAndCsrInput.pbtxt |  4 ++
 ...seDenseMatmulGradWithAdamAndCsrInput.pbtxt |  4 ++
 ...seDenseMatmulGradWithFtrlAndCsrInput.pbtxt |  4 ++
 ...rseDenseMatmulGradWithSgdAndCsrInput.pbtxt |  4 ++
 ...def_XlaSparseDenseMatmulWithCsrInput.pbtxt |  4 ++
 .../tpu/ops/tpu_copy_with_dynamic_shape_op.cc |  6 +-
 tensorflow/python/tpu/ops/BUILD               |  4 ++
 .../tools/api/golden/v1/tensorflow.pbtxt      | 68 -------------------
 .../tools/api/golden/v2/tensorflow.pbtxt      | 68 -------------------
 41 files changed, 157 insertions(+), 142 deletions(-)
 create mode 100644 tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagrad.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseCoreAdam.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseCoreFtrl.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseCoreSgd.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmul.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/base_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_ConvertToCooTensor.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_TPUCopyWithDynamicShape.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagrad.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdagradMomentum.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseCoreAdam.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseCoreFtrl.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseCoreSgd.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmul.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt
 create mode 100644 tensorflow/core/api_def/python_api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt

diff --git a/tensorflow/core/api_def/BUILD b/tensorflow/core/api_def/BUILD
index 4884972676b864..b3efda8f4791b7 100644
--- a/tensorflow/core/api_def/BUILD
+++ b/tensorflow/core/api_def/BUILD
@@ -6,12 +6,12 @@
 #   :python_api_def
 #   :java_api_def
 
-load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//tensorflow:tensorflow.bzl",
     "tf_cc_binary",
     "tf_cc_test",
 )
+load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
 load(
     "//third_party/mkl:build_defs.bzl",
     "if_mkl",
@@ -116,5 +116,8 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core/platform:resource_loader",
+        "//tensorflow/core/tpu/ops:sparse_core_ops",
+        "//tensorflow/core/tpu/ops:sparse_core_preprocess_ops",
+        "//tensorflow/core/tpu/ops:tpu_copy_with_dynamic_shape_op",
     ],
 )
diff --git a/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt
new file mode 100644
index 00000000000000..c9d629514f49f7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ConvertToCooTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "ConvertToCooTensor"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
new file mode 100644
index 00000000000000..e402d2bf92de67
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchSplitsWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt b/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
new file mode 100644
index 00000000000000..49493ee2254354
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "GetMinibatchesInCsrWithPhysicalReplica"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt b/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
new file mode 100644
index 00000000000000..9545ffdf9c4b7e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StoreMinibatchStatisticsInFdo.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "StoreMinibatchStatisticsInFdo"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
new file mode 100644
index 00000000000000..84ac58c2e7e5ea
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "TPUAnnotateTensorsWithDynamicShape"
+  visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt
new file mode 100644
index 00000000000000..423e1a22244a4b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TPUCopyWithDynamicShape.pbtxt
@@ -0,0 +1,8 @@
+op {
+  graph_op_name: "TPUCopyWithDynamicShape"
+  visibility: HIDDEN
+  summary: <